diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py
index 4a202786d..57262504d 100755
--- a/.ci/update_windows/update.py
+++ b/.ci/update_windows/update.py
@@ -66,8 +66,10 @@ if branch is None:
try:
ref = repo.lookup_reference('refs/remotes/origin/master')
except:
- print("pulling.") # noqa: T201
- pull(repo)
+ print("fetching.") # noqa: T201
+ for remote in repo.remotes:
+ if remote.name == "origin":
+ remote.fetch()
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)
branch = repo.lookup_branch('master')
@@ -170,3 +172,4 @@ try:
shutil.copy(stable_update_script, stable_update_script_to)
except:
pass
+
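For context, a minimal standalone sketch of the fetch-then-checkout flow the updater now uses (assuming pygit2 and a checkout that already has an `origin` remote; the repository path here is illustrative):

```python
# Minimal sketch of the fetch-then-checkout flow used by the updater above
# (assumes pygit2 is installed and the checkout already has an "origin" remote).
import pygit2

repo = pygit2.Repository(".")  # path to the ComfyUI checkout (illustrative)

# Fetch the latest refs from origin instead of pulling/merging.
for remote in repo.remotes:
    if remote.name == "origin":
        remote.fetch()

# Check out the remote-tracking master branch, as the updater does next.
ref = repo.lookup_reference("refs/remotes/origin/master")
repo.checkout(ref)
```

Fetching and checking out the remote-tracking ref avoids the merge that the previous `pull()` call performed.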
diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt
index 96a500be2..2cbb00d99 100755
--- a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt
+++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt
@@ -1,5 +1,5 @@
-As of the time of writing this you need this preview driver for best results:
-https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
+As of the time of writing, you need this driver for best results:
+https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
HOW TO RUN:
@@ -25,3 +25,4 @@ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
+
diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml
index 9274b4170..d72ece2ce 100644
--- a/.github/workflows/release-stable-all.yml
+++ b/.github/workflows/release-stable-all.yml
@@ -65,11 +65,11 @@ jobs:
contents: "write"
packages: "write"
pull-requests: "read"
- name: "Release AMD ROCm 6.4.4"
+ name: "Release AMD ROCm 7.1.1"
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
- cache_tag: "rocm644"
+ cache_tag: "rocm711"
python_minor: "12"
python_patch: "10"
rel_name: "amd"
diff --git a/CODEOWNERS b/CODEOWNERS
index b7aca9b26..4d5448636 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,3 +1,2 @@
# Admins
-* @comfyanonymous
-* @kosinkadink
+* @comfyanonymous @kosinkadink @guill
diff --git a/README.md b/README.md
index b9300ab07..bae955b1b 100644
--- a/README.md
+++ b/README.md
@@ -68,6 +68,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
+ - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@@ -80,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
+ - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@@ -318,6 +320,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
2. Launch ComfyUI by running `python main.py`
+
+## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
+
+**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
+
+### Setup
+
+1. Install the manager dependencies:
+ ```bash
+ pip install -r manager_requirements.txt
+ ```
+
+2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
+ ```bash
+ python main.py --enable-manager
+ ```
+
+### Command Line Options
+
+| Flag | Description |
+|------|-------------|
+| `--enable-manager` | Enable ComfyUI-Manager |
+| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
+| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
+
+
# Running
```python main.py```
diff --git a/app/user_manager.py b/app/user_manager.py
index a2d376c0c..e2c00dab2 100644
--- a/app/user_manager.py
+++ b/app/user_manager.py
@@ -59,6 +59,9 @@ class UserManager():
user = "default"
if args.multi_user and "comfy-user" in request.headers:
user = request.headers["comfy-user"]
+        # Block system users (reuse the same error message as below to avoid probing for valid usernames)
+ if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
+ raise KeyError("Unknown user: " + user)
if user not in self.users:
raise KeyError("Unknown user: " + user)
@@ -66,15 +69,16 @@ class UserManager():
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
- user_directory = folder_paths.get_user_directory()
-
if type == "userdata":
- root_dir = user_directory
+ root_dir = folder_paths.get_user_directory()
else:
raise KeyError("Unknown filepath type:" + type)
user = self.get_request_user_id(request)
- path = user_root = os.path.abspath(os.path.join(root_dir, user))
+ user_root = folder_paths.get_public_user_directory(user)
+ if user_root is None:
+ return None
+ path = user_root
# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
@@ -101,7 +105,11 @@ class UserManager():
name = name.strip()
if not name:
raise ValueError("username not provided")
+ if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
+ raise ValueError("System User prefix not allowed")
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
+ if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
+ raise ValueError("System User prefix not allowed")
user_id = user_id + "_" + str(uuid.uuid4())
self.users[user_id] = name
@@ -132,7 +140,10 @@ class UserManager():
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)
- user_id = self.add_user(username)
+ try:
+ user_id = self.add_user(username)
+ except ValueError as e:
+ return web.json_response({"error": str(e)}, status=400)
return web.json_response(user_id)
@routes.get("/userdata")
@@ -424,7 +435,7 @@ class UserManager():
return source
dest = get_user_data_path(request, check_exists=False, param="dest")
- if not isinstance(source, str):
+ if not isinstance(dest, str):
return dest
overwrite = request.query.get("overwrite", 'true') != "false"
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 1c8ef0c1f..209fc185b 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -137,7 +137,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
-parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
+parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
+parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
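To make the new flag semantics concrete, a small standalone sketch of how the `nargs='?'`/`const=2` combination behaves (plain argparse, not the full ComfyUI parser):

```python
# Standalone illustration of the argparse behaviour behind --async-offload:
# flag omitted -> None, bare flag -> const (2), explicit value -> that value.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--async-offload", nargs="?", const=2, type=int,
                    default=None, metavar="NUM_STREAMS")
parser.add_argument("--disable-async-offload", action="store_true")

print(parser.parse_args([]).async_offload)                         # None
print(parser.parse_args(["--async-offload"]).async_offload)        # 2
print(parser.parse_args(["--async-offload", "4"]).async_offload)   # 4
```

This lets `--async-offload` act both as an on/off switch and as a stream-count knob, while `--disable-async-offload` covers platforms where it is enabled by default.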
diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index 041f380f9..5c412d1c2 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -51,26 +51,36 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC):
- def __init__(self, index_list: list[int], dim: int=0):
+ def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.dim = dim
+ self.total_frames = total_frames
+        self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) if total_frames > 0 else 0.0
- def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
+ def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None:
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
- idx = [slice(None)] * dim + [self.index_list]
- return full[idx].to(device)
+ idx = tuple([slice(None)] * dim + [self.index_list])
+ window = full[idx]
+ if retain_index_list:
+ idx = tuple([slice(None)] * dim + [retain_index_list])
+ window[idx] = full[idx]
+ return window.to(device)
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
- idx = [slice(None)] * dim + [self.index_list]
+ idx = tuple([slice(None)] * dim + [self.index_list])
full[idx] += to_add
return full
+ def get_region_index(self, num_regions: int) -> int:
+ region_idx = int(self.center_ratio * num_regions)
+ return min(max(region_idx, 0), num_regions - 1)
+
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@@ -94,7 +104,8 @@ class ContextFuseMethod:
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
- def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
+ def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
+                 closed_loop: bool=False, dim: int=0, freenoise: bool=False, cond_retain_index_list: str="", split_conds_to_windows: bool=False):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@@ -103,13 +114,18 @@ class IndexListContextHandler(ContextHandlerABC):
self.closed_loop = closed_loop
self.dim = dim
self._step = 0
+ self.freenoise = freenoise
+ self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
+ self.split_conds_to_windows = split_conds_to_windows
self.callbacks = {}
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
- logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
+ logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
+ if self.cond_retain_index_list:
+ logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
return True
return False
@@ -123,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC):
return None
# reuse or resize cond items to match context requirements
resized_cond = []
+ # if multiple conds, split based on primary region
+ if self.split_conds_to_windows and len(cond_in) > 1:
+ region = window.get_region_index(len(cond_in))
+            logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
+ cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:
resized_actual_cond = actual_cond.copy()
@@ -146,12 +167,19 @@ class IndexListContextHandler(ContextHandlerABC):
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
if isinstance(cond_value, torch.Tensor):
- if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
+                        if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
+ (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
+ # Handle audio_embed (temporal dim is 1)
+ elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+ audio_cond = cond_value.cond
+ if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
+ new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
- if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
- new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
+ if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
+ (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
+ new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
new_cond_item[cond_key].cond = window.context_length
@@ -164,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
- mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
+ mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
@@ -173,7 +201,7 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
- context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
+ context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
@@ -250,8 +278,8 @@ class IndexListContextHandler(ContextHandlerABC):
prev_weight = (bias_total / (bias_total + bias))
new_weight = (bias / (bias_total + bias))
# account for dims of tensors
- idx_window = [slice(None)] * self.dim + [idx]
- pos_window = [slice(None)] * self.dim + [pos]
+ idx_window = tuple([slice(None)] * self.dim + [idx])
+ pos_window = tuple([slice(None)] * self.dim + [pos])
# apply new values
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
biases_final[i][idx] = bias_total + bias
@@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
)
+def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
+ model_options = extra_args.get("model_options", None)
+ if model_options is None:
+ raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+ handler: IndexListContextHandler = model_options.get("context_handler", None)
+ if handler is None:
+ raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+ if not handler.freenoise:
+ return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+ noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
+
+ return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+
+
+def create_sampler_sample_wrapper(model: ModelPatcher):
+ model.add_wrapper_with_key(
+ comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
+ "ContextWindows_sampler_sample",
+ _sampler_sample_wrapper
+ )
+
+
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
@@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta
+
+
+# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
+def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
+ logging.info("Context windows: Applying FreeNoise")
+ generator = torch.Generator(device='cpu').manual_seed(seed)
+ latent_video_length = noise.shape[dim]
+ delta = context_length - context_overlap
+
+ for start_idx in range(0, latent_video_length - context_length, delta):
+ place_idx = start_idx + context_length
+
+ actual_delta = min(delta, latent_video_length - place_idx)
+ if actual_delta <= 0:
+ break
+
+ list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
+
+ source_slice = [slice(None)] * noise.ndim
+ source_slice[dim] = list_idx
+ target_slice = [slice(None)] * noise.ndim
+ target_slice[dim] = slice(place_idx, place_idx + actual_delta)
+
+ noise[tuple(target_slice)] = noise[tuple(source_slice)]
+
+ return noise
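A toy walk-through (illustrative numbers) of the window-to-region mapping added above: `center_ratio` places a window within the full clip, and `get_region_index` picks which cond that window uses when `split_conds_to_windows` is enabled.

```python
# Toy calculation mirroring IndexListContextWindow.center_ratio and
# get_region_index above; numbers are illustrative.
index_list = list(range(32, 48))   # one 16-frame context window
total_frames = 81                  # full latent length
num_regions = 3                    # e.g. three prompts to split across windows

center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)   # ~0.49
region_idx = min(max(int(center_ratio * num_regions), 0), num_regions - 1)
print(region_idx)                  # 1 -> this mid-clip window uses the middle prompt
```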
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 8e110f45d..f1ca0151e 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -431,6 +431,7 @@ class HunyuanVideo(LatentFormat):
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
+ taesd_decoder_name = "taehv"
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
@@ -494,7 +495,7 @@ class Wan21(LatentFormat):
]).view(1, self.latent_channels, 1, 1, 1)
- self.taesd_decoder_name = None #TODO
+ self.taesd_decoder_name = "lighttaew2_1"
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
@@ -565,6 +566,7 @@ class Wan22(Wan21):
def __init__(self):
self.scale_factor = 1.0
+ self.taesd_decoder_name = "lighttaew2_2"
self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
@@ -719,6 +721,7 @@ class HunyuanVideo15(LatentFormat):
latent_channels = 32
latent_dimensions = 3
scale_factor = 1.03682
+ taesd_decoder_name = "lighttaehy1_5"
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index a72f8cc47..2e8ef0687 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -40,7 +40,8 @@ class ChromaParams:
out_dim: int
hidden_dim: int
n_layers: int
-
+ txt_ids_dims: list
+ vec_in_dim: int
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 2472ab79c..60f2bdae2 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module):
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
+class YakMLP(nn.Module):
+ def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+ self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+ self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
+ self.act_fn = nn.SiLU()
+
+ def forward(self, x: Tensor) -> Tensor:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
+ if yak_mlp:
+ return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
+ if mlp_silu_act:
+ return nn.Sequential(
+ operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+ SiLUActivation(),
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+ )
+ else:
+ return nn.Sequential(
+ operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+ nn.GELU(approximate="tanh"),
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+ )
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
@@ -140,7 +169,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module):
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- if mlp_silu_act:
- self.img_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
- SiLUActivation(),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
- )
- else:
- self.img_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
- nn.GELU(approximate="tanh"),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
- )
+ self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
@@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module):
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- if mlp_silu_act:
- self.txt_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
- SiLUActivation(),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
- )
- else:
- self.txt_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
- nn.GELU(approximate="tanh"),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
- )
+ self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
@@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module):
modulation=True,
mlp_silu_act=False,
bias=True,
+ yak_mlp=False,
dtype=None,
device=None,
operations=None
@@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module):
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_hidden_dim_first = self.mlp_hidden_dim
+ self.yak_mlp = yak_mlp
if mlp_silu_act:
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
self.mlp_act = SiLUActivation()
else:
self.mlp_act = nn.GELU(approximate="tanh")
+ if self.yak_mlp:
+ self.mlp_hidden_dim_first *= 2
+ self.mlp_act = nn.SiLU()
+
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
# proj and mlp_out
@@ -325,7 +338,10 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
- mlp = self.mlp_act(mlp)
+ if self.yak_mlp:
+ mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
+ else:
+ mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
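As a reference point, a tensor-level sketch of the gated-SiLU MLP that `YakMLP` adds and of the equivalent split used in the fused single-stream path (plain `torch.nn` here rather than the comfy operations wrappers; sizes are illustrative):

```python
# Sketch of the gated-SiLU MLP math used by YakMLP above.
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, inter = 64, 256
gate_proj = nn.Linear(hidden, inter)
up_proj = nn.Linear(hidden, inter)
down_proj = nn.Linear(inter, hidden)

x = torch.randn(2, 10, hidden)
out = down_proj(F.silu(gate_proj(x)) * up_proj(x))   # DoubleStreamBlock / YakMLP path

# Fused single-stream equivalent: one projection of width 2*inter, with the
# second half gated through SiLU and multiplied into the first half.
fused = torch.randn(2, 10, 2 * inter)
mlp = F.silu(fused[..., inter:]) * fused[..., :inter]
```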
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 1a24e6d95..f40c2a7a9 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -15,7 +15,8 @@ from .layers import (
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
- Modulation
+ Modulation,
+ RMSNorm
)
@dataclass
@@ -34,11 +35,14 @@ class FluxParams:
patch_size: int
qkv_bias: bool
guidance_embed: bool
+ txt_ids_dims: list
global_modulation: bool = False
mlp_silu_act: bool = False
ops_bias: bool = True
default_ref_method: str = "offset"
ref_index_scale: float = 1.0
+ yak_mlp: bool = False
+ txt_norm: bool = False
class Flux(nn.Module):
@@ -76,6 +80,11 @@ class Flux(nn.Module):
)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
+ if params.txt_norm:
+ self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+ else:
+ self.txt_norm = None
+
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@@ -86,6 +95,7 @@ class Flux(nn.Module):
modulation=params.global_modulation is False,
mlp_silu_act=params.mlp_silu_act,
proj_bias=params.ops_bias,
+ yak_mlp=params.yak_mlp,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -94,7 +104,7 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList(
[
- SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
@@ -150,6 +160,8 @@ class Flux(nn.Module):
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+ if self.txt_norm is not None:
+ txt = self.txt_norm(txt)
txt = self.txt_in(txt)
vec_orig = vec
@@ -171,7 +183,10 @@ class Flux(nn.Module):
pe = None
blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.double_blocks)
+ transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
+ transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -215,7 +230,10 @@ class Flux(nn.Module):
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
+ transformer_options["total_blocks"] = len(self.single_blocks)
+ transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
+ transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@@ -326,8 +344,9 @@ class Flux(nn.Module):
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
- if len(self.params.axes_dim) == 4: # Flux 2
- txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+ if len(self.params.txt_ids_dims) > 0:
+ for i in self.params.txt_ids_dims:
+ txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
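A toy illustration (made-up sizes) of the `txt_ids_dims` handling above: a 0..N-1 token index is written into each selected axis of the text position ids, generalizing the previous Flux 2 hardcoding of the last axis.

```python
# Toy example of filling selected txt_ids axes with a per-token index.
import torch

bs, n_txt, n_axes = 1, 5, 4
txt_ids_dims = [3]                 # Flux 2 style: last axis only
txt_ids = torch.zeros((bs, n_txt, n_axes))
for i in txt_ids_dims:
    txt_ids[:, :, i] = torch.linspace(0, n_txt - 1, steps=n_txt)
print(txt_ids[0, :, 3])            # tensor([0., 1., 2., 3., 4.])
```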
diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py
index 9f5e91a59..85f515f67 100644
--- a/comfy/ldm/hunyuan_video/upsampler.py
+++ b/comfy/ldm/hunyuan_video/upsampler.py
@@ -1,7 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
+from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
import model_management, model_patcher
class SRResidualCausalBlock3D(nn.Module):
diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py
index 9f750dcc4..ddf77cd0e 100644
--- a/comfy/ldm/hunyuan_video/vae_refiner.py
+++ b/comfy/ldm/hunyuan_video/vae_refiner.py
@@ -1,42 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
import comfy.ops
import comfy.ldm.models.autoencoder
import comfy.model_management
ops = comfy.ops.disable_weight_init
-class NoPadConv3d(nn.Module):
- def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
- super().__init__()
- self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
-
- def forward(self, x):
- return self.conv(x)
-
-
-def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
-
- x = xl[0]
- xl.clear()
-
- if conv_carry_out is not None:
- to_push = x[:, :, -2:, :, :].clone()
- conv_carry_out.append(to_push)
-
- if isinstance(op, NoPadConv3d):
- if conv_carry_in is None:
- x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
- else:
- carry_len = conv_carry_in[0].shape[2]
- x = torch.cat([conv_carry_in.pop(0), x], dim=2)
- x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
-
- out = op(x)
-
- return out
-
class RMS_norm(nn.Module):
def __init__(self, dim):
@@ -49,7 +19,7 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
class DnSmpl(nn.Module):
- def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
+ def __init__(self, ic, oc, tds, refiner_vae, op):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
@@ -109,7 +79,7 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module):
- def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
+ def __init__(self, ic, oc, tus, refiner_vae, op):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@@ -163,23 +133,6 @@ class UpSmpl(nn.Module):
return h + x
-class HunyuanRefinerResnetBlock(ResnetBlock):
- def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
- super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
-
- def forward(self, x, conv_carry_in=None, conv_carry_out=None):
- h = x
- h = [ self.swish(self.norm1(x)) ]
- h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
- h = [ self.dropout(self.swish(self.norm2(h))) ]
- h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
- if self.in_channels != self.out_channels:
- x = self.nin_shortcut(x)
-
- return x+h
-
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@@ -191,7 +144,7 @@ class Encoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
- conv_op = NoPadConv3d
+ conv_op = CarriedConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@@ -206,9 +159,10 @@ class Encoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
- stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
- out_channels=tgt,
- conv_op=conv_op, norm_op=norm_op)
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
@@ -218,9 +172,9 @@ class Encoder(nn.Module):
self.down.append(stage)
self.mid = nn.Module()
- self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
- self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -246,22 +200,20 @@ class Encoder(nn.Module):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
+
x1 = [ x1 ]
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for stage in self.down:
for blk in stage.block:
- x1 = blk(x1, conv_carry_in, conv_carry_out)
+ x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'downsample'):
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
- if len(out) > 1:
- out = torch.cat(out, dim=2)
- else:
- out = out[0]
+ out = torch_cat_if_needed(out, dim=2)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
del out
@@ -288,7 +240,7 @@ class Decoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
- conv_op = NoPadConv3d
+ conv_op = CarriedConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@@ -298,9 +250,9 @@ class Decoder(nn.Module):
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
- self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
- self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@@ -308,9 +260,10 @@ class Decoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
- stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
- out_channels=tgt,
- conv_op=conv_op, norm_op=norm_op)
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
@@ -340,7 +293,7 @@ class Decoder(nn.Module):
conv_carry_out = None
for stage in self.up:
for blk in stage.block:
- x1 = blk(x1, conv_carry_in, conv_carry_out)
+ x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'upsample'):
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
@@ -350,10 +303,7 @@ class Decoder(nn.Module):
conv_carry_in = conv_carry_out
del x
- if len(out) > 1:
- out = torch.cat(out, dim=2)
- else:
- out = out[0]
+ out = torch_cat_if_needed(out, dim=2)
if not self.refiner_vae:
if z.shape[-3] == 1:
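A hedged sketch of what `torch_cat_if_needed` stands in for here: it replaces the removed "concatenate only if more than one chunk" pattern. The real helper lives in `comfy.ldm.modules.diffusionmodules.model` and may differ in detail.

```python
# Sketch of the behaviour the removed inline code had, now provided by a helper.
import torch

def torch_cat_if_needed(tensors, dim=0):
    if len(tensors) > 1:
        return torch.cat(tensors, dim=dim)
    return tensors[0]

chunks = [torch.randn(1, 4, 2, 8, 8), torch.randn(1, 4, 3, 8, 8)]
print(torch_cat_if_needed(chunks, dim=2).shape)   # torch.Size([1, 4, 5, 8, 8])
```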
diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
new file mode 100644
index 000000000..1509de2f8
--- /dev/null
+++ b/comfy/ldm/kandinsky5/model.py
@@ -0,0 +1,413 @@
+import torch
+from torch import nn
+import math
+
+import comfy.ldm.common_dit
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.math import apply_rope1
+from comfy.ldm.flux.layers import EmbedND
+
+def attention(q, k, v, heads, transformer_options={}):
+ return optimized_attention(
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ heads=heads,
+ skip_reshape=True,
+ transformer_options=transformer_options
+ )
+
+def apply_scale_shift_norm(norm, x, scale, shift):
+ return torch.addcmul(shift, norm(x), scale + 1.0)
+
+def apply_gate_sum(x, out, gate):
+ return torch.addcmul(x, gate, out)
+
+def get_shift_scale_gate(params):
+ shift, scale, gate = torch.chunk(params, 3, dim=-1)
+ return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
+
+def get_freqs(dim, max_period=10000.0):
+ return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
+
+
+class TimeEmbeddings(nn.Module):
+ def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
+ super().__init__()
+ assert model_dim % 2 == 0
+ self.model_dim = model_dim
+ self.max_period = max_period
+ self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.activation = nn.SiLU()
+ self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, timestep, dtype):
+ args = torch.outer(timestep, self.freqs.to(device=timestep.device))
+ time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
+ time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
+ return time_embed
+
+
+class TextEmbeddings(nn.Module):
+ def __init__(self, text_dim, model_dim, operation_settings=None):
+ super().__init__()
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, text_embed):
+ text_embed = self.in_layer(text_embed)
+ return self.norm(text_embed).type_as(text_embed)
+
+
+class VisualEmbeddings(nn.Module):
+ def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
+ super().__init__()
+ self.patch_size = patch_size
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, x):
+ x = x.movedim(1, -1) # B C T H W -> B T H W C
+ B, T, H, W, dim = x.shape
+ pt, ph, pw = self.patch_size
+
+ x = x.view(
+ B,
+ T // pt, pt,
+ H // ph, ph,
+ W // pw, pw,
+ dim,
+ ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
+
+ return self.in_layer(x)
+
+
+class Modulation(nn.Module):
+ def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
+ super().__init__()
+ self.activation = nn.SiLU()
+ self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, x):
+ return self.out_layer(self.activation(x))
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, num_channels, head_dim, operation_settings=None):
+ super().__init__()
+ assert num_channels % head_dim == 0
+ self.num_heads = num_channels // head_dim
+ self.head_dim = head_dim
+
+ operations = operation_settings.get("operations")
+ self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.num_chunks = 2
+
+ def _compute_qk(self, x, freqs, proj_fn, norm_fn):
+ result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
+ return apply_rope1(norm_fn(result), freqs)
+
+ def _forward(self, x, freqs, transformer_options={}):
+ q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
+ k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
+ v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
+ out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+ return self.out_layer(out)
+
+ def _forward_chunked(self, x, freqs, transformer_options={}):
+ def process_chunks(proj_fn, norm_fn):
+ x_chunks = torch.chunk(x, self.num_chunks, dim=1)
+ freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
+ chunks = []
+ for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
+ chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
+ return torch.cat(chunks, dim=1)
+
+ q = process_chunks(self.to_query, self.query_norm)
+ k = process_chunks(self.to_key, self.key_norm)
+ v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
+ out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+ return self.out_layer(out)
+
+ def forward(self, x, freqs, transformer_options={}):
+ if x.shape[1] > 8192:
+ return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+ else:
+ return self._forward(x, freqs, transformer_options=transformer_options)
+
+
+class CrossAttention(SelfAttention):
+ def get_qkv(self, x, context):
+ q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+ k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
+ v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
+ return q, k, v
+
+ def forward(self, x, context, transformer_options={}):
+ q, k, v = self.get_qkv(x, context)
+ out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
+ return self.out_layer(out)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, ff_dim, operation_settings=None):
+ super().__init__()
+ operations = operation_settings.get("operations")
+ self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.activation = nn.GELU()
+ self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.num_chunks = 4
+
+ def _forward(self, x):
+ return self.out_layer(self.activation(self.in_layer(x)))
+
+ def _forward_chunked(self, x):
+ chunks = torch.chunk(x, self.num_chunks, dim=1)
+ output_chunks = []
+ for chunk in chunks:
+ output_chunks.append(self._forward(chunk))
+ return torch.cat(output_chunks, dim=1)
+
+ def forward(self, x):
+ if x.shape[1] > 8192:
+ return self._forward_chunked(x)
+ else:
+ return self._forward(x)
+
+
+class OutLayer(nn.Module):
+ def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
+ super().__init__()
+ self.patch_size = patch_size
+ self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
+ operations = operation_settings.get("operations")
+ self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, visual_embed, time_embed):
+ B, T, H, W, _ = visual_embed.shape
+ shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
+ scale = scale[:, None, None, None, :]
+ shift = shift[:, None, None, None, :]
+ visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
+ x = self.out_layer(visual_embed)
+
+ out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
+ x = x.view(
+ B, T, H, W,
+ out_dim,
+ self.patch_size[0], self.patch_size[1], self.patch_size[2]
+ )
+ return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
+
+
+class TransformerEncoderBlock(nn.Module):
+ def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
+ super().__init__()
+ self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
+ operations = operation_settings.get("operations")
+
+ self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
+
+ self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
+
+ def forward(self, x, time_embed, freqs, transformer_options={}):
+ self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
+ shift, scale, gate = get_shift_scale_gate(self_attn_params)
+ out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
+ out = self.self_attention(out, freqs, transformer_options=transformer_options)
+ x = apply_gate_sum(x, out, gate)
+
+ shift, scale, gate = get_shift_scale_gate(ff_params)
+ out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
+ out = self.feed_forward(out)
+ x = apply_gate_sum(x, out, gate)
+ return x
+
+
+class TransformerDecoderBlock(nn.Module):
+ def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
+ super().__init__()
+ self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
+
+ operations = operation_settings.get("operations")
+ self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
+
+ self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
+
+ self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
+
+ def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+ self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
+ # self attention
+ shift, scale, gate = get_shift_scale_gate(self_attn_params)
+ visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
+ visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
+ visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+ # cross attention
+ shift, scale, gate = get_shift_scale_gate(cross_attn_params)
+ visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
+ visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
+ visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+ # feed forward
+ shift, scale, gate = get_shift_scale_gate(ff_params)
+ visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
+ visual_out = self.feed_forward(visual_out)
+ visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
+ return visual_embed
+
+
+class Kandinsky5(nn.Module):
+ def __init__(
+ self,
+ in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
+ model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
+ axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
+ dtype=None, device=None, operations=None, **kwargs
+ ):
+ super().__init__()
+ head_dim = sum(axes_dims)
+ self.rope_scale_factor = rope_scale_factor
+ self.in_visual_dim = in_visual_dim
+ self.model_dim = model_dim
+ self.patch_size = patch_size
+ self.visual_embed_dim = visual_embed_dim
+ self.dtype = dtype
+ self.device = device
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+ self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
+ self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
+ self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
+ self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
+
+ self.text_transformer_blocks = nn.ModuleList(
+ [TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
+ )
+
+ self.visual_transformer_blocks = nn.ModuleList(
+ [TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
+ )
+
+ self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
+
+ self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
+ self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
+
+ def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
+ steps = seq_len if steps is None else steps
+ seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
+ seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
+ freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
+ return freqs
+
+ def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
+
+ patch_size = self.patch_size
+ t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
+ h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
+ w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
+
+ if steps_t is None:
+ steps_t = t_len
+ if steps_h is None:
+ steps_h = h_len
+ if steps_w is None:
+ steps_w = w_len
+
+ h_start = 0
+ w_start = 0
+ rope_options = transformer_options.get("rope_options", None)
+ if rope_options is not None:
+ t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
+ h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+ w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+ t_start += rope_options.get("shift_t", 0.0)
+ h_start += rope_options.get("shift_y", 0.0)
+ w_start += rope_options.get("shift_x", 0.0)
+ else:
+ rope_scale_factor = self.rope_scale_factor
+ if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
+ if h * w >= 14080:
+ rope_scale_factor = (1.0, 3.16, 3.16)
+
+ t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
+ h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
+ w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
+
+ img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
+ img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
+ img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
+ img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
+ img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
+
+ freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
+ return freqs
+
+ def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
+ patches_replace = transformer_options.get("patches_replace", {})
+ context = self.text_embeddings(context)
+ time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
+
+ for block in self.text_transformer_blocks:
+ context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
+
+ visual_embed = self.visual_embeddings(x)
+ visual_shape = visual_embed.shape[:-1]
+ visual_embed = visual_embed.flatten(1, -2)
+
+ blocks_replace = patches_replace.get("dit", {})
+ transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
+ transformer_options["block_type"] = "double"
+ for i, block in enumerate(self.visual_transformer_blocks):
+ transformer_options["block_index"] = i
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
+ visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
+ else:
+ visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
+
+ visual_embed = visual_embed.reshape(*visual_shape, -1)
+ return self.out_layer(visual_embed, time_embed)
+
+ def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+ original_dims = x.ndim
+ if original_dims == 4:
+ x = x.unsqueeze(2)
+ bs, c, t_len, h, w = x.shape
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+
+ if time_dim_replace is not None:
+ time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
+ x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
+
+ freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
+ freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
+
+ out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+ if original_dims == 4:
+ out = out.squeeze(2)
+ return out
+
+ def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
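For orientation, a shape-only sketch (toy sizes) of the patchify step performed by `VisualEmbeddings.forward`: B C T H W latents become B T' H' W' tokens of width pt*ph*pw*C for `patch_size` (1, 2, 2).

```python
# Shape-only sketch of the VisualEmbeddings patchify above; sizes are toy values.
import torch

B, C, T, H, W = 1, 16, 9, 32, 32
pt, ph, pw = 1, 2, 2

x = torch.randn(B, C, T, H, W).movedim(1, -1)          # B T H W C
x = x.view(B, T // pt, pt, H // ph, ph, W // pw, pw, C)
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)    # B T' H' W' (pt*ph*pw*C)
print(x.shape)                                         # torch.Size([1, 9, 16, 16, 64])
```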
diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py
new file mode 100644
index 000000000..fd7ce3b5c
--- /dev/null
+++ b/comfy/ldm/lumina/controlnet.py
@@ -0,0 +1,113 @@
+import torch
+from torch import nn
+
+from .model import JointTransformerBlock
+
+class ZImageControlTransformerBlock(JointTransformerBlock):
+ def __init__(
+ self,
+ layer_id: int,
+ dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ multiple_of: int,
+ ffn_dim_multiplier: float,
+ norm_eps: float,
+ qk_norm: bool,
+ modulation=True,
+ block_id=0,
+ operation_settings=None,
+ ):
+ super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
+ self.block_id = block_id
+ if block_id == 0:
+ self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, c, x, **kwargs):
+ if self.block_id == 0:
+ c = self.before_proj(c) + x
+ c = super().forward(c, **kwargs)
+ c_skip = self.after_proj(c)
+ return c_skip, c
+
+class ZImage_Control(torch.nn.Module):
+ def __init__(
+ self,
+ dim: int = 3840,
+ n_heads: int = 30,
+ n_kv_heads: int = 30,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: float = (8.0 / 3.0),
+ norm_eps: float = 1e-5,
+ qk_norm: bool = True,
+ dtype=None,
+ device=None,
+ operations=None,
+ **kwargs
+ ):
+ super().__init__()
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+ self.additional_in_dim = 0
+ self.control_in_dim = 16
+ n_refiner_layers = 2
+ self.n_control_layers = 6
+ self.control_layers = nn.ModuleList(
+ [
+ ZImageControlTransformerBlock(
+ i,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ block_id=i,
+ operation_settings=operation_settings,
+ )
+ for i in range(self.n_control_layers)
+ ]
+ )
+
+ all_x_embedder = {}
+ patch_size = 2
+ f_patch_size = 1
+ x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
+ all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
+
+ self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
+ self.control_noise_refiner = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ modulation=True,
+ z_image_modulation=True,
+ operation_settings=operation_settings,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+
+ def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
+ patch_size = 2
+ f_patch_size = 1
+ pH = pW = patch_size
+ B, C, H, W = control_context.shape
+ control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+
+ x_attn_mask = None
+ for layer in self.control_noise_refiner:
+ control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+ return control_context
+
+ def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+ return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
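
For orientation, the view/permute/flatten chain in ZImage_Control.forward above is the usual 2x2 patchify: a (B, C, H, W) control latent becomes (B, H/p * W/p, p*p*C) tokens. A toy-sized illustration (shapes invented for the example):

    import torch

    B, C, H, W, p = 1, 16, 8, 8, 2
    x = torch.randn(B, C, H, W)
    tokens = x.view(B, C, H // p, p, W // p, p).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
    print(tokens.shape)  # torch.Size([1, 16, 64]) -> 16 patches of p * p * C = 64 values each
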
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index c8643eb82..c47df49ca 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -22,6 +22,10 @@ def modulate(x, scale):
# Core NextDiT Model #
#############################################################################
+def clamp_fp16(x):
+ if x.dtype == torch.float16:
+ return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+ return x
class JointAttention(nn.Module):
"""Multi-head attention module."""
@@ -169,7 +173,7 @@ class FeedForward(nn.Module):
# @torch.compile
def _forward_silu_gating(self, x1, x3):
- return F.silu(x1) * x3
+ return clamp_fp16(F.silu(x1) * x3)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@@ -273,27 +277,27 @@ class JointTransformerBlock(nn.Module):
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
- self.attention(
+ clamp_fp16(self.attention(
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
- )
+ ))
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
- self.feed_forward(
+ clamp_fp16(self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
- )
+ ))
)
else:
assert adaln_input is None
x = x + self.attention_norm2(
- self.attention(
+ clamp_fp16(self.attention(
self.attention_norm1(x),
x_mask,
freqs_cis,
transformer_options=transformer_options,
- )
+ ))
)
x = x + self.ffn_norm2(
self.feed_forward(
@@ -373,6 +377,7 @@ class NextDiT(nn.Module):
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
+ clip_text_dim=None,
image_model=None,
device=None,
dtype=None,
@@ -443,6 +448,31 @@ class NextDiT(nn.Module):
),
)
+ self.clip_text_pooled_proj = None
+
+ if clip_text_dim is not None:
+ self.clip_text_dim = clip_text_dim
+ self.clip_text_pooled_proj = nn.Sequential(
+ operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
+ operation_settings.get("operations").Linear(
+ clip_text_dim,
+ clip_text_dim,
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
+ self.time_text_embed = nn.Sequential(
+ nn.SiLU(),
+ operation_settings.get("operations").Linear(
+ min(dim, 1024) + clip_text_dim,
+ min(dim, 1024),
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
+
self.layers = nn.ModuleList(
[
JointTransformerBlock(
@@ -509,7 +539,7 @@ class NextDiT(nn.Module):
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
- cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
+ cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
@@ -517,15 +547,27 @@ class NextDiT(nn.Module):
B, C, H, W = x.shape
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+ rope_options = transformer_options.get("rope_options", None)
+ h_scale = 1.0
+ w_scale = 1.0
+ h_start = 0
+ w_start = 0
+ if rope_options is not None:
+ h_scale = rope_options.get("scale_y", 1.0)
+ w_scale = rope_options.get("scale_x", 1.0)
+
+ h_start = rope_options.get("shift_y", 0.0)
+ w_start = rope_options.get("shift_x", 0.0)
+
H_tokens, W_tokens = H // pH, W // pW
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
- x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
- x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+ x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
+ x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
- x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
+ x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
@@ -552,7 +594,7 @@ class NextDiT(nn.Module):
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
- def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@@ -569,16 +611,32 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
- transformer_options = kwargs.get("transformer_options", {})
+ if self.clip_text_pooled_proj is not None:
+ pooled = kwargs.get("clip_text_pooled", None)
+ if pooled is not None:
+ pooled = self.clip_text_pooled_proj(pooled)
+ else:
+ pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
+
+ adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
+
+ patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
- x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
- freqs_cis = freqs_cis.to(x.device)
+ img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+ freqs_cis = freqs_cis.to(img.device)
- for layer in self.layers:
- x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+ for i, layer in enumerate(self.layers):
+ img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+ if "double_block" in patches:
+ for p in patches["double_block"]:
+ out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+ if "img" in out:
+ img[:, cap_size[0]:] = out["img"]
+ if "txt" in out:
+ img[:, :cap_size[0]] = out["txt"]
- x = self.final_layer(x, adaln_input)
- x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
+ img = self.final_layer(img, adaln_input)
+ img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
- return -x
+ return -img
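
The clamp_fp16 helper added above exists because float16 tops out at 65504: activations that overflow become inf and then poison the output with NaNs. A small demonstration of the guarded case (values chosen only to trigger the overflow):

    import torch

    x = torch.tensor([70000.0, -70000.0, float("nan")], dtype=torch.float16)  # 70000 overflows to inf
    print(torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504))
    # -> tensor([ 65504., -65504., 0.], dtype=torch.float16)
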
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 7437e0567..a8800ded0 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -517,6 +517,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+ exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@@ -541,6 +542,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+ exception_fallback = True
+ if exception_fallback:
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
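
The attention_sage change above moves the PyTorch fallback out of the except block, so the heavy fallback path no longer runs while the sage exception is still being handled. The pattern in isolation (function names are placeholders):

    import logging

    def compute_with_fallback(fast, slow, *args):
        use_fallback = False
        try:
            return fast(*args)
        except Exception as e:
            logging.error("fast path failed: %s, using fallback", e)
            use_fallback = True
        if use_fallback:
            return slow(*args)

    print(compute_with_fallback(lambda x: 1 / x, lambda x: float("inf"), 0))  # inf
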
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 4245eedca..681a55db5 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
+def torch_cat_if_needed(xl, dim):
+ if len(xl) > 1:
+ return torch.cat(xl, dim)
+ else:
+ return xl[0]
+
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+class CarriedConv3d(nn.Module):
+ def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
+ super().__init__()
+ self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
+
+ x = xl[0]
+ xl.clear()
+
+ if isinstance(op, CarriedConv3d):
+ if conv_carry_in is None:
+ x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode='replicate')
+ else:
+ carry_len = conv_carry_in[0].shape[2]
+ x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode='replicate')
+ x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+
+ if conv_carry_out is not None:
+ to_push = x[:, :, -2:, :, :].clone()
+ conv_carry_out.append(to_push)
+
+ out = op(x)
+
+ return out
+
+
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
@@ -89,29 +126,24 @@ class Upsample(nn.Module):
stride=1,
padding=1)
- def forward(self, x):
+ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
scale_factor = self.scale_factor
if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
- t = x.shape[2]
- if t > 1:
- a, b = x.split((1, t - 1), dim=2)
- del x
- b = interpolate_up(b, scale_factor)
- else:
- a = x
-
- a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
- if t > 1:
- x = torch.cat((a, b), dim=2)
- else:
- x = a
+ results = []
+ if conv_carry_in is None:
+ first = x[:, :, :1, :, :]
+ results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
+ x = x[:, :, 1:, :, :]
+ if x.shape[2] > 0:
+ results.append(interpolate_up(x, scale_factor))
+ x = torch_cat_if_needed(results, dim=2)
else:
x = interpolate_up(x, scale_factor)
if self.with_conv:
- x = self.conv(x)
+ x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
return x
@@ -127,17 +159,20 @@ class Downsample(nn.Module):
stride=stride,
padding=0)
- def forward(self, x):
+ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
if self.with_conv:
- if x.ndim == 4:
+ if isinstance(self.conv, CarriedConv3d):
+ x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+ elif x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
+ x = self.conv(x)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
- x = self.conv(x)
+ x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
@@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
- def forward(self, x, temb=None):
+ def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
h = x
h = self.norm1(h)
- h = self.swish(h)
- h = self.conv1(h)
+ h = [ self.swish(h) ]
+ h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if temb is not None:
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
h = self.norm2(h)
h = self.swish(h)
- h = self.dropout(h)
- h = self.conv2(h)
+ h = [ self.dropout(h) ]
+ h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
+ x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
else:
x = self.nin_shortcut(x)
@@ -279,6 +314,7 @@ def pytorch_attention(q, k, v):
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
+ oom_fallback = False
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@@ -289,6 +325,8 @@ def pytorch_attention(q, k, v):
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+ oom_fallback = True
+ if oom_fallback:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
@@ -517,9 +555,14 @@ class Encoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
+ self.carried = False
if conv3d:
- conv_op = VideoConv3d
+ if not attn_resolutions:
+ conv_op = CarriedConv3d
+ self.carried = True
+ else:
+ conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@@ -532,6 +575,7 @@ class Encoder(nn.Module):
stride=1,
padding=1)
+ self.time_compress = 1
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult
@@ -558,10 +602,15 @@ class Encoder(nn.Module):
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
+ else:
+ self.time_compress *= 2
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
+ if time_compress is not None:
+ self.time_compress = time_compress
+
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
@@ -587,15 +636,42 @@ class Encoder(nn.Module):
def forward(self, x):
# timestep embedding
temb = None
- # downsampling
- h = self.conv_in(x)
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](h, temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- if i_level != self.num_resolutions-1:
- h = self.down[i_level].downsample(h)
+
+ if self.carried:
+ xl = [x[:, :, :1, :, :]]
+ if x.shape[2] > self.time_compress:
+ tc = self.time_compress
+ xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim=2)
+ x = xl
+ else:
+ x = [x]
+ out = []
+
+ conv_carry_in = None
+
+ for i, x1 in enumerate(x):
+ conv_carry_out = []
+ if i == len(x) - 1:
+ conv_carry_out = None
+
+ # downsampling
+ x1 = [ x1 ]
+ h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
+ if len(self.down[i_level].attn) > 0:
+ assert i == 0  # the carried (chunked) path must never be taken when attention blocks exist
+ h1 = self.down[i_level].attn[i_block](h1)
+ if i_level != self.num_resolutions-1:
+ h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
+
+ out.append(h1)
+ conv_carry_in = conv_carry_out
+
+ h = torch_cat_if_needed(out, dim=2)
+ del out
# middle
h = self.mid.block_1(h, temb)
@@ -604,15 +680,15 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
+ h = [ nonlinearity(h) ]
+ h = conv_carry_causal_3d(h, self.conv_out)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ resolution, z_channels, tanh_out=False, use_linear_attn=False,
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
@@ -626,12 +702,18 @@ class Decoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
- self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
+ self.carried = False
if conv3d:
- conv_op = VideoConv3d
- conv_out_op = VideoConv3d
+ if not attn_resolutions and resnet_op == ResnetBlock:
+ conv_op = CarriedConv3d
+ conv_out_op = CarriedConv3d
+ self.carried = True
+ else:
+ conv_op = VideoConv3d
+ conv_out_op = VideoConv3d
+
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@@ -706,29 +788,43 @@ class Decoder(nn.Module):
temb = None
# z to block_in
- h = self.conv_in(z)
+ h = conv_carry_causal_3d([z], self.conv_in)
# middle
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
+ if self.carried:
+ h = torch.split(h, 2, dim=2)
+ else:
+ h = [ h ]
+ out = []
+
+ conv_carry_in = None
+
# upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
- h = self.up[i_level].block[i_block](h, temb, **kwargs)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h, **kwargs)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
+ for i, h1 in enumerate(h):
+ conv_carry_out = []
+ if i == len(h) - 1:
+ conv_carry_out = None
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
+ if len(self.up[i_level].attn) > 0:
+ assert i == 0  # the carried (chunked) path must never be taken when attention blocks exist
+ h1 = self.up[i_level].attn[i_block](h1, **kwargs)
+ if i_level != 0:
+ h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
- # end
- if self.give_pre_end:
- return h
+ h1 = self.norm_out(h1)
+ h1 = [ nonlinearity(h1) ]
+ h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
+ if self.tanh_out:
+ h1 = torch.tanh(h1)
+ out.append(h1)
+ conv_carry_in = conv_carry_out
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h, **kwargs)
- if self.tanh_out:
- h = torch.tanh(h)
- return h
+ out = torch_cat_if_needed(out, dim=2)
+
+ return out
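
The CarriedConv3d/conv_carry_causal_3d machinery above lets the VAE process a video in temporal chunks: each chunk keeps its last two (spatially padded) frames as left context for the next one, so a causal 3D conv sees the same neighborhood it would on the full clip. A self-contained sketch of the idea with made-up kernel and chunk sizes:

    import torch
    import torch.nn.functional as F

    def causal_conv3d_chunked(x, weight, chunk=4):
        outs, carry = [], None
        for start in range(0, x.shape[2], chunk):
            part = x[:, :, start:start + chunk]
            if carry is None:
                part = F.pad(part, (1, 1, 1, 1, 2, 0), mode="replicate")  # causal pad on first chunk
            else:
                part = torch.cat([carry, F.pad(part, (1, 1, 1, 1, 0, 0), mode="replicate")], dim=2)
            carry = part[:, :, -2:].clone()  # carry two trailing frames to the next chunk
            outs.append(F.conv3d(part, weight))
        return torch.cat(outs, dim=2)

    x = torch.randn(1, 3, 8, 6, 6)
    w = torch.randn(5, 3, 3, 3, 3)
    full = F.conv3d(F.pad(x, (1, 1, 1, 1, 2, 0), mode="replicate"), w)
    print(torch.allclose(causal_conv3d_chunked(x, w), full, atol=1e-4))  # True
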
diff --git a/comfy/lora.py b/comfy/lora.py
index 36d26293a..2ed0acb9d 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -313,6 +313,23 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(key_lora)] = k
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
+ if isinstance(model, comfy.model_base.Lumina2):
+ diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+ for k in diffusers_keys:
+ if k.endswith(".weight"):
+ to = diffusers_keys[k]
+ key_lora = k[:-len(".weight")]
+ key_map["diffusion_model.{}".format(key_lora)] = to
+ key_map["transformer.{}".format(key_lora)] = to
+ key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
+
+ if isinstance(model, comfy.model_base.Kandinsky5):
+ for k in sdk:
+ if k.startswith("diffusion_model.") and k.endswith(".weight"):
+ key_lora = k[len("diffusion_model."):-len(".weight")]
+ key_map["{}".format(key_lora)] = k
+ key_map["transformer.{}".format(key_lora)] = k
+
return key_map
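
For context, the Kandinsky5 branch added above just aliases every diffusion-model weight under its bare name and under a transformer. prefix, so LoRA files produced by external trainers can still be matched. Reduced to a toy state-dict key (the key itself is invented):

    sdk = ["diffusion_model.visual_transformer_blocks.0.self_attention.to_query.weight"]

    key_map = {}
    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")]
            key_map[key_lora] = k
            key_map["transformer.{}".format(key_lora)] = k

    print(key_map)
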
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 9b76c285e..6b8a8454d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
+import comfy.ldm.kandinsky5.model
import comfy.model_management
import comfy.patcher_extension
@@ -134,7 +135,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
- operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
+ operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -329,18 +330,6 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
-
- if self.model_config.scaled_fp8 is not None:
- unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
-
- # Save mixed precision metadata
- if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
- metadata = {
- "format_version": "1.0",
- "layers": self.model_config.layer_quant_config
- }
- unet_state_dict["_quantization_metadata"] = metadata
-
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
@@ -1121,6 +1110,10 @@ class Lumina2(BaseModel):
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
+ clip_text_pooled = kwargs["pooled_output"] # Newbie
+ if clip_text_pooled is not None:
+ out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
return out
class WAN21(BaseModel):
@@ -1642,3 +1635,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(False)
return out
+
+class Kandinsky5(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
+
+ def encode_adm(self, **kwargs):
+ return kwargs["pooled_output"]
+
+ def concat_cond(self, **kwargs):
+ noise = kwargs.get("noise", None)
+ device = kwargs["device"]
+ image = torch.zeros_like(noise)
+
+ mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+ if mask is None:
+ mask = torch.zeros_like(noise)[:, :1]
+ else:
+ mask = 1.0 - mask
+ mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ if mask.shape[-3] < noise.shape[-3]:
+ mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+ mask = utils.resize_to_batch_size(mask, noise.shape[0])
+
+ return torch.cat((image, mask), dim=1)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ attention_mask = kwargs.get("attention_mask", None)
+ if attention_mask is not None:
+ out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ time_dim_replace = kwargs.get("time_dim_replace", None)
+ if time_dim_replace is not None:
+ out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
+
+ return out
+
+class Kandinsky5Image(Kandinsky5):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device)
+
+ def concat_cond(self, **kwargs):
+ return None
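
As a rough picture of what Kandinsky5.concat_cond above feeds the model: an all-zero image latent plus an inverted denoise mask, concatenated along the channel axis (Kandinsky5Image drops this entirely). Shapes below are invented for illustration:

    import torch

    noise = torch.randn(1, 16, 4, 32, 32)       # (B, C, T, H, W) latent noise
    denoise_mask = torch.ones(1, 1, 4, 32, 32)  # 1 = denoise this region
    image = torch.zeros_like(noise)
    mask = 1.0 - denoise_mask                   # model convention: 1 = keep fixed
    print(torch.cat((image, mask), dim=1).shape)  # torch.Size([1, 17, 4, 32, 32])
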
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 7afe4a798..74c547427 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -6,20 +6,6 @@ import math
import logging
import torch
-
-def detect_layer_quantization(metadata):
- quant_key = "_quantization_metadata"
- if metadata is not None and quant_key in metadata:
- quant_metadata = metadata.pop(quant_key)
- quant_metadata = json.loads(quant_metadata)
- if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
- logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
- return quant_metadata["layers"]
- else:
- raise ValueError("Invalid quantization metadata format")
- return None
-
-
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
@@ -208,12 +194,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["theta"] = 2000
dit_config["out_channels"] = 128
dit_config["global_modulation"] = True
- dit_config["vec_in_dim"] = None
dit_config["mlp_silu_act"] = True
dit_config["qkv_bias"] = False
dit_config["ops_bias"] = False
dit_config["default_ref_method"] = "index"
dit_config["ref_index_scale"] = 10.0
+ dit_config["txt_ids_dims"] = [3]
patch_size = 1
else:
dit_config["image_model"] = "flux"
@@ -223,6 +209,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["theta"] = 10000
dit_config["out_channels"] = 16
dit_config["qkv_bias"] = True
+ dit_config["txt_ids_dims"] = []
patch_size = 2
dit_config["in_channels"] = 16
@@ -245,6 +232,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+ else:
+ dit_config["vec_in_dim"] = None
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
@@ -270,6 +259,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_embedder_dtype"] = torch.float32
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+ dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
+ dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+ if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
+ dit_config["txt_ids_dims"] = [1, 2]
+
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
@@ -429,6 +423,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
+ ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
+ if ctd_weight is not None:
+ dit_config["clip_text_dim"] = ctd_weight.shape[0]
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
@@ -617,6 +614,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
+ if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
+ dit_config = {}
+ model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+ dit_config["model_dim"] = model_dim
+ if model_dim in [4096, 2560]: # pro video and lite image
+ dit_config["axes_dims"] = (32, 48, 48)
+ if model_dim == 2560: # lite image
+ dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
+ elif model_dim == 1792: # lite video
+ dit_config["axes_dims"] = (16, 24, 24)
+ dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+ dit_config["image_model"] = "kandinsky5"
+ dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
+ dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
+ dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
+ dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
+ return dit_config
+
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -759,22 +774,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config)
- scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
- if scaled_fp8_key in state_dict:
- scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
- model_config.scaled_fp8 = scaled_fp8_weight.dtype
- if model_config.scaled_fp8 == torch.float32:
- model_config.scaled_fp8 = torch.float8_e4m3fn
- if scaled_fp8_weight.nelement() == 2:
- model_config.optimizations["fp8"] = False
- else:
- model_config.optimizations["fp8"] = True
-
# Detect per-layer quantization (mixed precision)
- layer_quant_config = detect_layer_quantization(metadata)
- if layer_quant_config:
- model_config.layer_quant_config = layer_quant_config
- logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+ quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
+ if quant_config:
+ model_config.quant_config = quant_config
+ logging.info("Detected mixed precision quantization")
return model_config
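
The Kandinsky 5 detection block above follows the usual pattern in this file: every config value is read off a tensor shape in the checkpoint rather than from a separate config file. A stripped-down version of that idea (keys and sizes are placeholders):

    import torch

    state_dict = {"visual_embeddings.in_layer.bias": torch.zeros(1792)}
    model_dim = state_dict["visual_embeddings.in_layer.bias"].shape[0]
    axes_dims = (16, 24, 24) if model_dim == 1792 else (32, 48, 48)
    print(model_dim, axes_dims)  # 1792 (16, 24, 24)
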
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a9327ac80..40717b1e4 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
- lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
+ lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0:
@@ -1012,9 +1012,18 @@ def force_channels_last():
STREAMS = {}
-NUM_STREAMS = 1
-if args.async_offload:
- NUM_STREAMS = 2
+NUM_STREAMS = 0
+if args.async_offload is not None:
+ NUM_STREAMS = args.async_offload
+else:
+ # Enable by default on Nvidia
+ if is_nvidia():
+ NUM_STREAMS = 2
+
+if args.disable_async_offload:
+ NUM_STREAMS = 0
+
+if NUM_STREAMS > 0:
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
def current_stream(device):
@@ -1030,7 +1039,10 @@ def current_stream(device):
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
- if NUM_STREAMS <= 1:
+ if NUM_STREAMS == 0:
+ return None
+
+ if torch.compiler.is_compiling():
return None
if device in STREAMS:
@@ -1043,7 +1055,9 @@ def get_offload_stream(device):
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
- ss.append(torch.cuda.Stream(device=device, priority=0))
+ s1 = torch.cuda.Stream(device=device, priority=0)
+ s1.as_context = torch.cuda.stream
+ ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
stream_counters[device] = stream_counter
@@ -1051,7 +1065,9 @@ def get_offload_stream(device):
elif is_device_xpu(device):
ss = []
for k in range(NUM_STREAMS):
- ss.append(torch.xpu.Stream(device=device, priority=0))
+ s1 = torch.xpu.Stream(device=device, priority=0)
+ s1.as_context = torch.xpu.stream
+ ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
stream_counters[device] = stream_counter
@@ -1069,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None or weight.dtype == dtype:
return weight
if stream is not None:
- with stream:
+ wf_context = stream
+ if hasattr(wf_context, "as_context"):
+ wf_context = wf_context.as_context(stream)
+ with wf_context:
return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy)
+
if stream is not None:
- with stream:
+ wf_context = stream
+ if hasattr(wf_context, "as_context"):
+ wf_context = wf_context.as_context(stream)
+ with wf_context:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:
@@ -1469,6 +1492,20 @@ def extended_fp16_support():
return True
+LORA_COMPUTE_DTYPES = {}
+def lora_compute_dtype(device):
+ dtype = LORA_COMPUTE_DTYPES.get(device, None)
+ if dtype is not None:
+ return dtype
+
+ if should_use_fp16(device):
+ dtype = torch.float16
+ else:
+ dtype = torch.float32
+
+ LORA_COMPUTE_DTYPES[device] = dtype
+ return dtype
+
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
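
The async-offload default added above can be summarized as a small decision table: an explicit args.async_offload value wins, otherwise Nvidia devices default to 2 streams, and args.disable_async_offload forces the count back to 0. A standalone sketch (the is_nvidia stub stands in for the real check):

    def pick_num_streams(async_offload=None, disable_async_offload=False, is_nvidia=lambda: True):
        num_streams = async_offload if async_offload is not None else (2 if is_nvidia() else 0)
        return 0 if disable_async_offload else num_streams

    print(pick_num_streams())                            # 2
    print(pick_num_streams(async_offload=4))             # 4
    print(pick_num_streams(disable_async_offload=True))  # 0
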
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 73adc7f70..a486c2723 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -35,6 +35,7 @@ import comfy.model_management
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
+from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@@ -126,27 +127,23 @@ class LowVramPatch:
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
- self.convert_func = convert_func
+ self.convert_func = convert_func # TODO: remove
self.set_func = set_func
def __call__(self, weight):
- intermediate_dtype = weight.dtype
- if self.convert_func is not None:
- weight = self.convert_func(weight, inplace=False)
+ return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
- if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
- intermediate_dtype = torch.float32
- out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
- if self.set_func is None:
- return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
- else:
- return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
- out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
- if self.set_func is not None:
- return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
- else:
- return out
+def low_vram_patch_estimate_vram(model, key):
+ weight, set_func, convert_func = get_key_weight(model, key)
+ if weight is None:
+ return 0
+ model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
+ if model_dtype is None:
+ model_dtype = weight.dtype
+
+ return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
set_func = None
@@ -269,6 +266,9 @@ class ModelPatcher:
if not hasattr(self.model, 'current_weight_patches_uuid'):
self.model.current_weight_patches_uuid = None
+ if not hasattr(self.model, 'model_offload_buffer_memory'):
+ self.model.model_offload_buffer_memory = 0
+
def model_size(self):
if self.size > 0:
return self.size
@@ -618,10 +618,11 @@ class ModelPatcher:
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
+ temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
if device_to is not None:
- temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+ temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
else:
- temp_weight = weight.to(torch.float32, copy=True)
+ temp_weight = weight.to(temp_dtype, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
@@ -662,7 +663,22 @@ class ModelPatcher:
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
- loading.append((comfy.model_management.module_size(m), n, m, params))
+ module_mem = comfy.model_management.module_size(m)
+ module_offload_mem = module_mem
+ if hasattr(m, "comfy_cast_weights"):
+ def check_module_offload_mem(key):
+ if key in self.patches:
+ return low_vram_patch_estimate_vram(self.model, key)
+ model_dtype = getattr(self.model, "manual_cast_dtype", None)
+ weight, _, _ = get_key_weight(self.model, key)
+ if model_dtype is None or weight is None:
+ return 0
+ if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
+ return weight.numel() * model_dtype.itemsize
+ return 0
+ module_offload_mem += check_module_offload_mem("{}.weight".format(n))
+ module_offload_mem += check_module_offload_mem("{}.bias".format(n))
+ loading.append((module_offload_mem, module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -676,20 +692,22 @@ class ModelPatcher:
load_completely = []
offloaded = []
+ offload_buffer = 0
loading.sort(reverse=True)
- for x in loading:
- n = x[1]
- m = x[2]
- params = x[3]
- module_mem = x[0]
+ for i, x in enumerate(loading):
+ module_offload_mem, module_mem, n, m, params = x
lowvram_weight = False
+ potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
+ lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
+
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"):
- if mem_counter + module_mem >= lowvram_model_memory:
+ if not lowvram_fits:
+ offload_buffer = potential_offload
lowvram_weight = True
lowvram_counter += 1
lowvram_mem_counter += module_mem
@@ -723,9 +741,11 @@ class ModelPatcher:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
- if full_load or mem_counter + module_mem < lowvram_model_memory:
+ if full_load or lowvram_fits:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
+ else:
+ offload_buffer = potential_offload
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -752,6 +772,8 @@ class ModelPatcher:
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
+ if comfy.model_management.is_device_cuda(device_to):
+ torch.cuda.synchronize()
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@@ -766,7 +788,7 @@ class ModelPatcher:
self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0:
- logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
+ logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
@@ -778,6 +800,7 @@ class ModelPatcher:
self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter
+ self.model.model_offload_buffer_memory = offload_buffer
self.model.current_weight_patches_uuid = self.patches_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
@@ -831,6 +854,7 @@ class ModelPatcher:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
+ self.model.model_offload_buffer_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
@@ -849,13 +873,18 @@ class ModelPatcher:
patch_counter = 0
unload_list = self._load_list()
unload_list.sort()
+
+ offload_buffer = self.model.model_offload_buffer_memory
+ if len(unload_list) > 0:
+ NS = comfy.model_management.NUM_STREAMS
+ offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
+
for unload in unload_list:
- if memory_to_free < memory_freed:
+ if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
- module_mem = unload[0]
- n = unload[1]
- m = unload[2]
- params = unload[3]
+ module_offload_mem, module_mem, n, m, params = unload
+
+ potential_offload = module_offload_mem + sum(offload_weight_factor)
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -901,20 +930,25 @@ class ModelPatcher:
patch_counter += 1
cast_weight = True
- if cast_weight:
+ if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
+ offload_buffer = max(offload_buffer, potential_offload)
+ offload_weight_factor.append(module_mem)
+ offload_weight_factor.pop(0)
logging.debug("freed {}".format(n))
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
+
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
- logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
+ self.model.model_offload_buffer_memory = offload_buffer
+ logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
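
The VRAM estimate used by low_vram_patch_estimate_vram above is simply weight.numel() times the itemsize of the compute dtype, scaled by a factor of 2 for the temporaries created while applying patches. Worked through for one invented layer size:

    import torch

    LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
    weight = torch.empty(3072, 3072)   # e.g. one attention projection matrix
    model_dtype = torch.float16
    estimate = weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
    print(estimate / (1024 * 1024), "MB")  # 36.0 MB
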
diff --git a/comfy/ops.py b/comfy/ops.py
index a0ff4e8f1..35237c9f7 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib
+import json
def run_every_op():
if torch.compiler.is_compiling():
@@ -95,6 +96,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if offload_stream is not None:
wf_context = offload_stream
+ if hasattr(wf_context, "as_context"):
+ wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()
@@ -109,22 +112,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
- if bias_has_function:
- with wf_context:
- for f in s.bias_function:
- bias = f(bias)
+ comfy.model_management.sync_stream(device, offload_stream)
+
+ bias_a = bias
+ weight_a = weight
+
+ if s.bias is not None:
+ for f in s.bias_function:
+ bias = f(bias)
if weight_has_function or weight.dtype != dtype:
- with wf_context:
- weight = weight.to(dtype=dtype)
- if isinstance(weight, QuantizedTensor):
- weight = weight.dequantize()
- for f in s.weight_function:
- weight = f(weight)
+ weight = weight.to(dtype=dtype)
+ if isinstance(weight, QuantizedTensor):
+ weight = weight.dequantize()
+ for f in s.weight_function:
+ weight = f(weight)
- comfy.model_management.sync_stream(device, offload_stream)
if offloadable:
- return weight, bias, offload_stream
+ return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
@@ -133,13 +138,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
- if weight is not None:
- device = weight.device
+ os, weight_a, bias_a = offload_stream
+ if os is None:
+ return
+ if weight_a is not None:
+ device = weight_a.device
else:
- if bias is None:
+ if bias_a is None:
return
- device = bias.device
- offload_stream.wait_stream(comfy.model_management.current_stream(device))
+ device = bias_a.device
+ os.wait_stream(comfy.model_management.current_stream(device))
class CastWeightBiasOp:
@@ -415,22 +423,12 @@ def fp8_linear(self, input):
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+ scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
- scale_weight = self.scale_weight
- scale_input = self.scale_input
- if scale_weight is None:
- scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
- else:
- scale_weight = scale_weight.to(input.device)
-
- if scale_input is None:
- scale_input = torch.ones((), device=input.device, dtype=torch.float32)
- input = torch.clamp(input, min=-448, max=448, out=input)
- layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
- quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
- else:
- scale_input = scale_input.to(input.device)
- quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
+ scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+ input = torch.clamp(input, min=-448, max=448, out=input)
+ layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+ quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
@@ -451,7 +449,7 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
- if not self.training:
+ if len(self.weight_function) == 0 and len(self.bias_function) == 0:
try:
out = fp8_linear(self, input)
if out is not None:
@@ -464,59 +462,6 @@ class fp8_ops(manual_cast):
uncast_bias_weight(self, weight, bias, offload_stream)
return x
-def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
- logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
- class scaled_fp8_op(manual_cast):
- class Linear(manual_cast.Linear):
- def __init__(self, *args, **kwargs):
- if override_dtype is not None:
- kwargs['dtype'] = override_dtype
- super().__init__(*args, **kwargs)
-
- def reset_parameters(self):
- if not hasattr(self, 'scale_weight'):
- self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-
- if not scale_input:
- self.scale_input = None
-
- if not hasattr(self, 'scale_input'):
- self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
- return None
-
- def forward_comfy_cast_weights(self, input):
- if fp8_matrix_mult:
- out = fp8_linear(self, input)
- if out is not None:
- return out
-
- weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-
- if weight.numel() < input.numel(): #TODO: optimize
- x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
- else:
- x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
- uncast_bias_weight(self, weight, bias, offload_stream)
- return x
-
- def convert_weight(self, weight, inplace=False, **kwargs):
- if inplace:
- weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
- return weight
- else:
- return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
-
- def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
- weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
- if return_weight:
- return weight
- if inplace_update:
- self.weight.data.copy_(weight)
- else:
- self.weight = torch.nn.Parameter(weight, requires_grad=False)
-
- return scaled_fp8_op
-
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
@@ -543,9 +488,9 @@ if CUBLAS_IS_AVAILABLE:
from .quant_ops import QuantizedTensor, QUANT_ALGOS
-def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
class MixedPrecisionOps(manual_cast):
- _layer_quant_config = layer_quant_config
+ _quant_config = quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm
@@ -588,27 +533,38 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
manually_loaded_keys = [weight_key]
- if layer_name not in MixedPrecisionOps._layer_quant_config:
+ layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+ if layer_conf is not None:
+ layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+ if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
- quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
- if quant_format is None:
+ self.quant_format = layer_conf.get("format", None)
+ if not self._full_precision_mm:
+ self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+
+ if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
- qconfig = QUANT_ALGOS[quant_format]
+ qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
weight_scale_key = f"{prefix}weight_scale"
+ scale = state_dict.pop(weight_scale_key, None)
+ if scale is not None:
+ scale = scale.to(device)
layout_params = {
- 'scale': state_dict.pop(weight_scale_key, None),
+ 'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
- if layout_params['scale'] is not None:
+
+ if scale is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter(
- QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
+ QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
requires_grad=False
)
@@ -617,7 +573,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
_v = state_dict.pop(param_key, None)
if _v is None:
continue
- setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+ self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@@ -626,6 +582,16 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
if key in missing_keys:
missing_keys.remove(key)
+ def state_dict(self, *args, destination=None, prefix="", **kwargs):
+ sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
+ if isinstance(self.weight, QuantizedTensor):
+ sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+ quant_conf = {"format": self.quant_format}
+ if self._full_precision_mm:
+ quant_conf["full_precision_matrix_mult"] = True
+ sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+ return sd
+
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
@@ -641,9 +607,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
- getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
- input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+ input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
@@ -654,7 +619,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
- weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+ weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
@@ -663,17 +628,28 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
self.weight = torch.nn.Parameter(weight, requires_grad=False)
+ def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
+ if recurse:
+ for module in self.children():
+ module._apply(fn)
+
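+ # Rebuild each parameter with register_parameter (rather than the default in-place .data update)
+ # so whatever fn returns, e.g. a QuantizedTensor moved to another device, becomes the parameter itself.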
+ for key, param in self._parameters.items():
+ if param is None:
+ continue
+ self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
+ for key, buf in self._buffers.items():
+ if buf is not None:
+ self._buffers[key] = fn(buf)
+ return self
+
return MixedPrecisionOps
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
- if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
- logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
- return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
-
- if scaled_fp8 is not None:
- return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
+ if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
+ logging.info("Using mixed precision operations")
+ return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
if (
fp8_compute and
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index d2f3e7397..571d3f760 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -235,8 +235,11 @@ class QuantizedTensor(torch.Tensor):
def is_pinned(self):
return self._qdata.is_pinned()
- def is_contiguous(self):
- return self._qdata.is_contiguous()
+ def is_contiguous(self, *arg, **kwargs):
+ return self._qdata.is_contiguous(*arg, **kwargs)
+
+ def storage(self):
+ return self._qdata.storage()
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
@@ -249,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn):
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
- if target_dtype is not None and target_dtype != qt.dtype:
- logging.warning(
- f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
- f"but not supported for quantized tensors. Ignoring dtype."
- )
-
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
@@ -274,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+ if target_dtype is not None:
+ new_params["orig_dtype"] = target_dtype
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
@@ -339,7 +338,9 @@ def generic_copy_(func, args, kwargs):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
+ orig_dtype = qt_dest._layout_params["orig_dtype"]
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
+ qt_dest._layout_params["orig_dtype"] = orig_dtype
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
@@ -397,17 +398,20 @@ class TensorCoreFP8Layout(QuantizedLayout):
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
orig_dtype = tensor.dtype
- if scale is None:
+ if isinstance(scale, str) and scale == "recalculate":
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
- if not isinstance(scale, torch.Tensor):
- scale = torch.tensor(scale)
- scale = scale.to(device=tensor.device, dtype=torch.float32)
+ if scale is not None:
+ if not isinstance(scale, torch.Tensor):
+ scale = torch.tensor(scale)
+ scale = scale.to(device=tensor.device, dtype=torch.float32)
- if inplace_ops:
- tensor *= (1.0 / scale).to(tensor.dtype)
+ if inplace_ops:
+ tensor *= (1.0 / scale).to(tensor.dtype)
+ else:
+ tensor = tensor * (1.0 / scale).to(tensor.dtype)
else:
- tensor = tensor * (1.0 / scale).to(tensor.dtype)
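+ # No scale given: store the values unscaled and record an identity (1.0) scale for dequantize.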
+ scale = torch.ones((), device=tensor.device, dtype=torch.float32)
if stochastic_rounding > 0:
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
@@ -425,7 +429,8 @@ class TensorCoreFP8Layout(QuantizedLayout):
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
- return plain_tensor * scale
+ plain_tensor.mul_(scale)
+ return plain_tensor
@classmethod
def get_plain_tensors(cls, qtensor):
diff --git a/comfy/sd.py b/comfy/sd.py
index 350fae92b..a16f2d14f 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -53,6 +53,8 @@ import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
+import comfy.text_encoders.ovis
+import comfy.text_encoders.kandinsky5
import comfy.model_patcher
import comfy.lora
@@ -60,6 +62,8 @@ import comfy.lora_convert
import comfy.hooks
import comfy.t2i_adapter.adapter
import comfy.taesd.taesd
+import comfy.taesd.taehv
+import comfy.latent_formats
import comfy.ldm.flux.redux
@@ -95,7 +99,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP:
- def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
+ def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
if no_init:
return
params = target.params.copy()
@@ -123,9 +127,32 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+ # Match the hardcoded torch.float32 upcast in the TE implementation
+ self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
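+ # state_dict may be a list of partial text encoder state dicts or one full checkpoint dict;
+ # loading them here keeps the missing/unexpected key reporting in a single place.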
+ if len(state_dict) > 0:
+ if isinstance(state_dict, list):
+ for c in state_dict:
+ m, u = self.load_sd(c)
+ if len(m) > 0:
+ logging.warning("clip missing: {}".format(m))
+
+ if len(u) > 0:
+ logging.debug("clip unexpected: {}".format(u))
+ else:
+ m, u = self.load_sd(state_dict, full_model=True)
+ if len(m) > 0:
+ m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
+ if len(m_filter) > 0:
+ logging.warning("clip missing: {}".format(m))
+ else:
+ logging.debug("clip missing: {}".format(m))
+
+ if len(u) > 0:
+ logging.debug("clip unexpected {}:".format(u))
+
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
@@ -190,6 +217,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
+ self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
@@ -237,6 +265,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
+ self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
@@ -466,7 +495,7 @@ class VAE:
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
- self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@@ -478,8 +507,10 @@ class VAE:
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
- self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
- self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+ # This is likely to significantly over-estimate memory for a single image or low frame counts,
+ # since the implementation can skip caching entirely. Rework this if it is used as an image-only VAE.
+ self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
+ self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
@@ -508,13 +539,14 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
+ dim = sd["decoder.head.0.gamma"].shape[0]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
- ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+ ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
@@ -584,6 +616,35 @@ class VAE:
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.crop_input = False
+ elif "decoder.22.bias" in sd: # taehv, taew and lighttae
+ self.latent_channels = sd["decoder.1.weight"].shape[1]
+ self.latent_dim = 3
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+ self.upscale_index_formula = (4, 16, 16)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+ self.downscale_index_formula = (4, 16, 16)
+ if self.latent_channels == 48: # Wan 2.2
+ self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
+ self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
+ self.process_output = lambda image: image
+ self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
+ elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
+ self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
+ self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
+ self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
+ else:
+ if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
+ latent_format=comfy.latent_formats.HunyuanVideo
+ else:
+ latent_format=None # lighttaew2_1 doesn't need scaling
+ self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
+ self.process_input = self.process_output = lambda image: image
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+ self.upscale_index_formula = (4, 8, 8)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+ self.downscale_index_formula = (4, 8, 8)
+ self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
+ self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -708,6 +769,8 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
+ if self.latent_dim == 2 and samples_in.ndim == 5:
+ samples_in = samples_in[:, :, 0]
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -924,16 +987,17 @@ class CLIPType(Enum):
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20
+ OVIS = 21
+ KANDINSKY5 = 22
+ KANDINSKY5_IMAGE = 23
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
for p in ckpt_paths:
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
- if metadata is not None:
- quant_metadata = metadata.get("_quantization_metadata", None)
- if quant_metadata is not None:
- sd["_quantization_metadata"] = quant_metadata
+ if model_options.get("custom_operations", None) is None:
+ sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
@@ -955,6 +1019,7 @@ class TEModel(Enum):
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
+ QWEN3_2B = 17
def detect_te_model(sd):
@@ -988,9 +1053,12 @@ def detect_te_model(sd):
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
- if 'model.layers.0.self_attn.q_norm.weight' in sd:
- return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight']
+ if 'model.layers.0.self_attn.q_norm.weight' in sd:
+ if weight.shape[0] == 2560:
+ return TEModel.QWEN3_4B
+ elif weight.shape[0] == 2048:
+ return TEModel.QWEN3_2B
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
@@ -1046,7 +1114,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
- clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
@@ -1070,7 +1138,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
- clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
+ clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
@@ -1099,7 +1167,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
- clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
+ clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
@@ -1118,13 +1186,16 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
+ elif te_model == TEModel.QWEN3_2B:
+ clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
- clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
@@ -1167,6 +1238,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
+ elif clip_type == CLIPType.KANDINSKY5:
+ clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
+ elif clip_type == CLIPType.KANDINSKY5_IMAGE:
+ clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -1179,19 +1256,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters = 0
for c in clip_data:
- if "_quantization_metadata" in c:
- c.pop("_quantization_metadata")
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
- clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
- for c in clip_data:
- m, u = clip.load_sd(c)
- if len(m) > 0:
- logging.warning("clip missing: {}".format(m))
-
- if len(u) > 0:
- logging.debug("clip unexpected: {}".format(u))
+ clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
return clip
def load_gligen(ckpt_path):
@@ -1250,6 +1318,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
+ custom_operations = model_options.get("custom_operations", None)
+ if custom_operations is None:
+ sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
+
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
@@ -1258,18 +1330,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
-
unet_weight_dtype = list(model_config.supported_inference_dtypes)
- if model_config.scaled_fp8 is not None:
+ if model_config.quant_config is not None:
weight_dtype = None
- model_config.custom_operations = model_options.get("custom_operations", None)
+ if custom_operations is not None:
+ model_config.custom_operations = custom_operations
+
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
- manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+ if model_config.quant_config is not None:
+ manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+ else:
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config.clip_vision_prefix is not None:
@@ -1287,22 +1363,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
vae = VAE(sd=vae_sd, metadata=metadata)
if output_clip:
+ if te_model_options.get("custom_operations", None) is None:
+ scaled_fp8_list = []
+ for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
+ if k.endswith(".scaled_fp8"):
+ scaled_fp8_list.append(k[:-len("scaled_fp8")])
+
+ if len(scaled_fp8_list) > 0:
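+ # Keep keys outside the legacy prefixes as-is, then rebuild each scaled_fp8 sub-model via convert_old_quants below.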
+ out_sd = {}
+ for k in sd:
+ skip = False
+ for pref in scaled_fp8_list:
+ skip = skip or k.startswith(pref)
+ if not skip:
+ out_sd[k] = sd[k]
+
+ for pref in scaled_fp8_list:
+ quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+ for k in quant_sd:
+ out_sd[k] = quant_sd[k]
+ sd = out_sd
+
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd)
- clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
- m, u = clip.load_sd(clip_sd, full_model=True)
- if len(m) > 0:
- m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
- if len(m_filter) > 0:
- logging.warning("clip missing: {}".format(m))
- else:
- logging.debug("clip missing: {}".format(m))
-
- if len(u) > 0:
- logging.debug("clip unexpected {}:".format(u))
+ clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
@@ -1349,6 +1436,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
if len(temp_sd) > 0:
sd = temp_sd
+ custom_operations = model_options.get("custom_operations", None)
+ if custom_operations is None:
+ sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
@@ -1379,7 +1469,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
- if model_config.scaled_fp8 is not None:
+ if model_config.quant_config is not None:
weight_dtype = None
if dtype is None:
@@ -1387,12 +1477,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
else:
unet_dtype = dtype
- if model_config.layer_quant_config is not None:
+ if model_config.quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
- model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
+
+ if custom_operations is not None:
+ model_config.custom_operations = custom_operations
+
if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True
@@ -1431,6 +1524,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if vae is not None:
vae_sd = vae.get_sd()
+ if metadata is None:
+ metadata = {}
+
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 0fc9ab3db..962948dae 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -107,29 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
config[k] = v
operations = model_options.get("custom_operations", None)
- scaled_fp8 = None
- quantization_metadata = model_options.get("quantization_metadata", None)
+ quant_config = model_options.get("quantization_metadata", None)
if operations is None:
- layer_quant_config = None
- if quantization_metadata is not None:
- layer_quant_config = json.loads(quantization_metadata).get("layers", None)
-
- if layer_quant_config is not None:
- operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
- logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
+ if quant_config is not None:
+ operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
+ logging.info("Using MixedPrecisionOps for text encoder")
else:
- # Fallback to scaled_fp8_ops for backward compatibility
- scaled_fp8 = model_options.get("scaled_fp8", None)
- if scaled_fp8 is not None:
- operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
- else:
- operations = comfy.ops.manual_cast
+ operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
- if scaled_fp8 is not None:
- self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
self.num_layers = self.transformer.num_layers
@@ -147,6 +135,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
self.return_attention_masks = return_attention_masks
+ self.execution_device = None
if layer == "hidden":
assert layer_idx is not None
@@ -163,6 +152,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
+ self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
@@ -175,6 +165,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = self.options_default[0]
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
+ self.execution_device = None
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
@@ -258,7 +249,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens):
- device = self.transformer.get_input_embeddings().weight.device
+ if self.execution_device is None:
+ device = self.transformer.get_input_embeddings().weight.device
+ else:
+ device = self.execution_device
+
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index af8120400..383c82c3e 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -21,6 +21,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
+import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
from . import supported_models_base
@@ -1027,6 +1028,8 @@ class ZImage(Lumina2):
memory_usage_factor = 1.7
+ supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
@@ -1472,7 +1475,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
+class Kandinsky5(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "kandinsky5",
+ }
+
+ sampling_settings = {
+ "shift": 10.0,
+ }
+
+ unet_extra_config = {}
+ latent_format = latent_formats.HunyuanVideo
+
+ memory_usage_factor = 1.1 #TODO
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Kandinsky5(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
+
+
+class Kandinsky5Image(Kandinsky5):
+ unet_config = {
+ "image_model": "kandinsky5",
+ "model_dim": 2560,
+ "visual_embed_dim": 64,
+ }
+
+ sampling_settings = {
+ "shift": 3.0,
+ }
+
+ latent_format = latent_formats.Flux
+ memory_usage_factor = 1.1 #TODO
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Kandinsky5Image(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
+
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid]
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index e4bd74514..0e7a829ba 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -17,6 +17,7 @@
"""
import torch
+import logging
from . import model_base
from . import utils
from . import latent_formats
@@ -49,8 +50,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
- scaled_fp8 = None
- layer_quant_config = None # Per-layer quantization configuration for mixed precision
+ quant_config = None # quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod
@@ -118,3 +118,7 @@ class BASE:
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype
+
+ def __getattr__(self, name):
+ logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
+ return None
diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py
new file mode 100644
index 000000000..3dfe1e4d4
--- /dev/null
+++ b/comfy/taesd/taehv.py
@@ -0,0 +1,171 @@
+# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm.auto import tqdm
+from collections import namedtuple, deque
+
+import comfy.ops
+operations = comfy.ops.disable_weight_init
+
+DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
+TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
+
+def conv(n_in, n_out, **kwargs):
+ return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+class Clamp(nn.Module):
+ def forward(self, x):
+ return torch.tanh(x / 3) * 3
+
+class MemBlock(nn.Module):
+ def __init__(self, n_in, n_out, act_func):
+ super().__init__()
+ self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
+ self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+ self.act = act_func
+ def forward(self, x, past):
+ return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
+
+class TPool(nn.Module):
+ def __init__(self, n_f, stride):
+ super().__init__()
+ self.stride = stride
+ self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False)
+ def forward(self, x):
+ _NT, C, H, W = x.shape
+ return self.conv(x.reshape(-1, self.stride * C, H, W))
+
+class TGrow(nn.Module):
+ def __init__(self, n_f, stride):
+ super().__init__()
+ self.stride = stride
+ self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
+ def forward(self, x):
+ _NT, C, H, W = x.shape
+ x = self.conv(x)
+ return x.reshape(-1, C, H, W)
+
+def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
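+ # Runs the encoder/decoder either on all timesteps at once (parallel=True) or frame by frame through a
+ # work queue, carrying per-block temporal state for the MemBlock/TPool/TGrow layers along the way.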
+
+ B, T, C, H, W = x.shape
+ if parallel:
+ x = x.reshape(B*T, C, H, W)
+ # parallel over input timesteps, iterate over blocks
+ for b in tqdm(model, disable=not show_progress_bar):
+ if isinstance(b, MemBlock):
+ BT, C, H, W = x.shape
+ T = BT // B
+ _x = x.reshape(B, T, C, H, W)
+ mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
+ x = b(x, mem)
+ else:
+ x = b(x)
+ BT, C, H, W = x.shape
+ T = BT // B
+ x = x.view(B, T, C, H, W)
+ else:
+ out = []
+ work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
+ progress_bar = tqdm(range(T), disable=not show_progress_bar)
+ mem = [None] * len(model)
+ while work_queue:
+ xt, i = work_queue.popleft()
+ if i == 0:
+ progress_bar.update(1)
+ if i == len(model):
+ out.append(xt)
+ del xt
+ else:
+ b = model[i]
+ if isinstance(b, MemBlock):
+ if mem[i] is None:
+ xt_new = b(xt, xt * 0)
+ mem[i] = xt.detach().clone()
+ else:
+ xt_new = b(xt, mem[i])
+ mem[i] = xt.detach().clone()
+ del xt
+ work_queue.appendleft(TWorkItem(xt_new, i+1))
+ elif isinstance(b, TPool):
+ if mem[i] is None:
+ mem[i] = []
+ mem[i].append(xt.detach().clone())
+ if len(mem[i]) == b.stride:
+ B, C, H, W = xt.shape
+ xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
+ mem[i] = []
+ work_queue.appendleft(TWorkItem(xt, i+1))
+ elif isinstance(b, TGrow):
+ xt = b(xt)
+ NT, C, H, W = xt.shape
+ for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
+ work_queue.appendleft(TWorkItem(xt_next, i+1))
+ del xt
+ else:
+ xt = b(xt)
+ work_queue.appendleft(TWorkItem(xt, i+1))
+ progress_bar.close()
+ x = torch.stack(out, 1)
+ return x
+
+
+class TAEHV(nn.Module):
+ def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
+ super().__init__()
+ self.image_channels = 3
+ self.patch_size = 1
+ self.latent_channels = latent_channels
+ self.parallel = parallel
+ self.latent_format = latent_format
+ self.show_progress_bar = show_progress_bar
+ self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
+ self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
+ if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
+ self.patch_size = 2
+ if self.latent_channels == 32: # HunyuanVideo1.5
+ act_func = nn.LeakyReLU(0.2, inplace=True)
+ else: # HunyuanVideo, Wan 2.1
+ act_func = nn.ReLU(inplace=True)
+
+ self.encoder = nn.Sequential(
+ conv(self.image_channels*self.patch_size**2, 64), act_func,
+ TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
+ TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
+ TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
+ conv(64, self.latent_channels),
+ )
+ n_f = [256, 128, 64, 64]
+ self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
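+ # Equals 3 when both temporal upscales are enabled; decode() drops these leading frames so the output has T*4 - 3 frames.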
+ self.decoder = nn.Sequential(
+ Clamp(), conv(self.latent_channels, n_f[0]), act_func,
+ MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
+ MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
+ MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
+ act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
+ )
+ @property
+ def show_progress_bar(self):
+ return self._show_progress_bar
+
+ @show_progress_bar.setter
+ def show_progress_bar(self, value):
+ self._show_progress_bar = value
+
+ def encode(self, x, **kwargs):
+ if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
+ x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
+ if x.shape[1] % 4 != 0:
+ # pad at end to multiple of 4
+ n_pad = 4 - x.shape[1] % 4
+ padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
+ x = torch.cat([x, padding], 1)
+ x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
+ return self.process_out(x)
+
+ def decode(self, x, **kwargs):
+ x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
+ x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
+ if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
+ return x[:, self.frames_to_trim:].movedim(2, 1)
diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py
index a1adb5242..448381fa9 100644
--- a/comfy/text_encoders/cosmos.py
+++ b/comfy/text_encoders/cosmos.py
@@ -7,10 +7,10 @@ from transformers import T5TokenizerFast
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
- t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
- if t5xxl_scaled_fp8 is not None:
+ t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
+ if t5xxl_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["quantization_metadata"] = t5xxl_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
@@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
-def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def te(dtype_t5=None, t5_quantization_metadata=None):
class CosmosTEModel_(CosmosT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py
index 99f4812bb..21d93d757 100644
--- a/comfy/text_encoders/flux.py
+++ b/comfy/text_encoders/flux.py
@@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
-def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
+def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_
@@ -159,15 +159,13 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
out = out.reshape(out.shape[0], out.shape[1], -1)
return out, pooled, extra
-def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
+def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
class Flux2TEModel_(Flux2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
- model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if pruned:
model_options = model_options.copy()
diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py
index 9dcf190a2..5daea8135 100644
--- a/comfy/text_encoders/genmo.py
+++ b/comfy/text_encoders/genmo.py
@@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
-def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
class MochiTEModel_(MochiT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py
index dbcf52784..600b34480 100644
--- a/comfy/text_encoders/hidream.py
+++ b/comfy/text_encoders/hidream.py
@@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module):
return self.llama.load_sd(sd)
-def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
+def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
class HiDreamTEModel_(HiDreamTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
- if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["llama_scaled_fp8"] = llama_scaled_fp8
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HiDreamTEModel_
diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py
index ff04726e1..cd198036c 100644
--- a/comfy/text_encoders/hunyuan_image.py
+++ b/comfy/text_encoders/hunyuan_image.py
@@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
- llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
- if llama_scaled_fp8 is not None:
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel):
else:
return super().load_sd(sd)
-def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
+def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
class QwenImageTEModel_(HunyuanImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py
index 0110517bb..a9a6c525e 100644
--- a/comfy/text_encoders/hunyuan_video.py
+++ b/comfy/text_encoders/hunyuan_video.py
@@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast
import torch
import os
import numbers
-
+import comfy.utils
def llama_detect(state_dict, prefix=""):
out = {}
@@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
- scaled_fp8_key = "{}scaled_fp8".format(prefix)
- if scaled_fp8_key in state_dict:
- out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
-
- if "_quantization_metadata" in state_dict:
- out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
+ quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
+ if quant is not None:
+ out["llama_quantization_metadata"] = quant
return out
@@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
- llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
- if llama_scaled_fp8 is not None:
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
textmodel_json_config = {}
vocab_size = model_options.get("vocab_size", None)
@@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
return self.llama.load_sd(sd)
-def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
+def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["llama_scaled_fp8"] = llama_scaled_fp8
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HunyuanVideoClipModel_
diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py
new file mode 100644
index 000000000..be086458c
--- /dev/null
+++ b/comfy/text_encoders/kandinsky5.py
@@ -0,0 +1,68 @@
+from comfy import sd1_clip
+from .qwen_image import QwenImageTokenizer, QwenImageTEModel
+from .llama import Qwen25_7BVLI
+
+
+class Kandinsky5Tokenizer(QwenImageTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+ self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+
+ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+ out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
+ out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+
+ return out
+
+
+class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+
+
+class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["quantization_metadata"] = llama_quantization_metadata
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class Kandinsky5TEModel(QwenImageTEModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+ self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+
+ def encode_token_weights(self, token_weight_pairs):
+ cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
+ l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
+
+ return cond, l_pooled, extra
+
+ def set_clip_options(self, options):
+ super().set_clip_options(options)
+ self.clip_l.set_clip_options(options)
+
+ def reset_clip_options(self):
+ super().reset_clip_options()
+ self.clip_l.reset_clip_options()
+
+ def load_sd(self, sd):
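+        # CLIP-L state dicts are identified by the text-model encoder key below; everything else is loaded into the Qwen2.5-VL model.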
+ if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
+ return self.clip_l.load_sd(sd)
+ else:
+ return super().load_sd(sd)
+
+def te(dtype_llama=None, llama_quantization_metadata=None):
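+    # Factory: bakes the dtype and quantization overrides into a subclass the text encoder loader can instantiate with only device/dtype/model_options.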
+ class Kandinsky5TEModel_(Kandinsky5TEModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
+ if dtype_llama is not None:
+ dtype = dtype_llama
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return Kandinsky5TEModel_
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index cd4b5f76c..0d07ac8c6 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -100,6 +100,28 @@ class Qwen3_4BConfig:
rope_scale = None
final_norm: bool = True
+@dataclass
+class Ovis25_2BConfig:
+ vocab_size: int = 151936
+ hidden_size: int = 2048
+ intermediate_size: int = 6144
+ num_hidden_layers: int = 28
+ num_attention_heads: int = 16
+ num_key_value_heads: int = 8
+ max_position_embeddings: int = 40960
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 1000000.0
+ transformer_type: str = "llama"
+ head_dim = 128
+ rms_norm_add = False
+ mlp_activation = "silu"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ rope_scale = None
+ final_norm: bool = True
+
@dataclass
class Qwen25_7BVLI_Config:
vocab_size: int = 152064
@@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
+class Ovis25_2B(BaseLlama, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Ovis25_2BConfig(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py
index fd986e2c1..7a6cfdab2 100644
--- a/comfy/text_encoders/lumina2.py
+++ b/comfy/text_encoders/lumina2.py
@@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
-def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
+def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
@@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
diff --git a/comfy/text_encoders/omnigen2.py b/comfy/text_encoders/omnigen2.py
index 1a01b2dd4..50aa4121f 100644
--- a/comfy/text_encoders/omnigen2.py
+++ b/comfy/text_encoders/omnigen2.py
@@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py
new file mode 100644
index 000000000..5754424d2
--- /dev/null
+++ b/comfy/text_encoders/ovis.py
@@ -0,0 +1,66 @@
+from transformers import Qwen2Tokenizer
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import os
+import torch
+import numbers
+
+class Qwen3Tokenizer(sd1_clip.SDTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class OvisTokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
+ self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
+
+ def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+ if llama_template is None:
+ llama_text = self.llama_template.format(text)
+ else:
+ llama_text = llama_template.format(text)
+
+ tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+ return tokens
+
+class Ovis25_2BModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
+
+
+class OvisTEModel(sd1_clip.SD1ClipModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
+
+ def encode_token_weights(self, token_weight_pairs, template_end=-1):
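+        # Trim the fixed chat-template prefix from the encoder output: template_end is located by scanning the token ids for the end-of-template marker, then the hidden states are sliced from that point on.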
+ out, pooled = super().encode_token_weights(token_weight_pairs)
+ tok_pairs = token_weight_pairs["qwen3_2b"][0]
+ count_im_start = 0
+ if template_end == -1:
+ for i, v in enumerate(tok_pairs):
+ elem = v[0]
+ if not torch.is_tensor(elem):
+ if isinstance(elem, numbers.Integral):
+ if elem == 4004 and count_im_start < 1:
+ template_end = i
+ count_im_start += 1
+
+ if out.shape[1] > (template_end + 1):
+ if tok_pairs[template_end + 1][0] == 25:
+ template_end += 1
+
+ out = out[:, template_end:]
+ return out, pooled, {}
+
+
+def te(dtype_llama=None, llama_quantization_metadata=None):
+ class OvisTEModel_(OvisTEModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if dtype_llama is not None:
+ dtype = dtype_llama
+            if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
+                model_options["quantization_metadata"] = llama_quantization_metadata
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return OvisTEModel_
diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py
index 5f383de07..e5e5f18be 100644
--- a/comfy/text_encoders/pixart_t5.py
+++ b/comfy/text_encoders/pixart_t5.py
@@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
-def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
index 67688e82c..df5b5d7fe 100644
--- a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
+++ b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
@@ -179,36 +179,36 @@
"special": false
},
"151665": {
- "content": "<|img|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
},
"151666": {
- "content": "<|endofimg|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
},
"151667": {
- "content": "<|meta|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
},
"151668": {
- "content": "<|endofmeta|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
}
},
"additional_special_tokens": [
diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py
index c0d32a6ef..5c14dec23 100644
--- a/comfy/text_encoders/qwen_image.py
+++ b/comfy/text_encoders/qwen_image.py
@@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
return out, pooled, extra
-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
class QwenImageTEModel_(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if llama_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
+ model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index ff5d412db..8b153c72b 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -6,14 +6,15 @@ import torch
import os
import comfy.model_management
import logging
+import comfy.utils
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
- t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
- if t5xxl_scaled_fp8 is not None:
+ t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
+ if t5xxl_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["quantization_metadata"] = t5xxl_quantization_metadata
model_options = {**model_options, "model_name": "t5xxl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_t5"] = state_dict[t5_key].dtype
- scaled_fp8_key = "{}scaled_fp8".format(prefix)
- if scaled_fp8_key in state_dict:
- out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
+ quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
+ if quant is not None:
+ out["t5_quantization_metadata"] = quant
return out
@@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
-def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
+def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_
diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py
index d50fa4b28..164a57edd 100644
--- a/comfy/text_encoders/wan.py
+++ b/comfy/text_encoders/wan.py
@@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
-def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def te(dtype_t5=None, t5_quantization_metadata=None):
class WanTEModel(WanT5Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ if t5_quantization_metadata is not None:
model_options = model_options.copy()
- model_options["scaled_fp8"] = t5xxl_scaled_fp8
+ model_options["quantization_metadata"] = t5_quantization_metadata
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py
index bb9273b20..19adde0b7 100644
--- a/comfy/text_encoders/z_image.py
+++ b/comfy/text_encoders/z_image.py
@@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
-def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
class ZImageTEModel_(ZImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
- if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
- model_options = model_options.copy()
- model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
diff --git a/comfy/utils.py b/comfy/utils.py
index 4bd281057..9dc0d76ac 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -29,6 +29,7 @@ import itertools
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
+import json
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@@ -675,6 +676,72 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
return key_map
+def z_image_to_diffusers(mmdit_config, output_prefix=""):
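+    # Maps diffusers-format Z Image checkpoint keys to ComfyUI's native names; the separate to_q/to_k/to_v projections are packed into one qkv tensor per attention block via the (key, (dim, offset, size)) convention.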
+ n_layers = mmdit_config.get("n_layers", 0)
+ hidden_size = mmdit_config.get("dim", 0)
+ n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
+ n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
+ key_map = {}
+
+ def add_block_keys(prefix_from, prefix_to, has_adaln=True):
+ for end in ("weight", "bias"):
+ k = "{}.attention.".format(prefix_from)
+ qkv = "{}.attention.qkv.{}".format(prefix_to, end)
+ key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
+ key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
+ key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
+
+ block_map = {
+ "attention.norm_q.weight": "attention.q_norm.weight",
+ "attention.norm_k.weight": "attention.k_norm.weight",
+ "attention.to_out.0.weight": "attention.out.weight",
+ "attention.to_out.0.bias": "attention.out.bias",
+ "attention_norm1.weight": "attention_norm1.weight",
+ "attention_norm2.weight": "attention_norm2.weight",
+ "feed_forward.w1.weight": "feed_forward.w1.weight",
+ "feed_forward.w2.weight": "feed_forward.w2.weight",
+ "feed_forward.w3.weight": "feed_forward.w3.weight",
+ "ffn_norm1.weight": "ffn_norm1.weight",
+ "ffn_norm2.weight": "ffn_norm2.weight",
+ }
+ if has_adaln:
+ block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
+ block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
+ for k, v in block_map.items():
+ key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)
+
+ for i in range(n_layers):
+ add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))
+
+ for i in range(n_context_refiner):
+ add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
+
+ for i in range(n_noise_refiner):
+ add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
+
+ MAP_BASIC = [
+ ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
+ ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
+ ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
+ ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
+ ("x_embedder.weight", "all_x_embedder.2-1.weight"),
+ ("x_embedder.bias", "all_x_embedder.2-1.bias"),
+ ("x_pad_token", "x_pad_token"),
+ ("cap_embedder.0.weight", "cap_embedder.0.weight"),
+ ("cap_embedder.1.weight", "cap_embedder.1.weight"),
+ ("cap_embedder.1.bias", "cap_embedder.1.bias"),
+ ("cap_pad_token", "cap_pad_token"),
+ ("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
+ ("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
+ ("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
+ ("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
+ ]
+
+ for c, diffusers in MAP_BASIC:
+ key_map[diffusers] = "{}{}".format(output_prefix, c)
+
+ return key_map
+
def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size)
@@ -736,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
return None
return f.read(length_of_header)
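+# Sentinel for set_attr: distinguishes a missing attribute from one stored as None, and requests deletion when passed as the value.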
+ATTR_UNSET = {}
+
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
- prev = getattr(obj, attrs[-1])
- setattr(obj, attrs[-1], value)
+ prev = getattr(obj, attrs[-1], ATTR_UNSET)
+ if value is ATTR_UNSET:
+ delattr(obj, attrs[-1])
+ else:
+ setattr(obj, attrs[-1], value)
return prev
def set_attr_param(obj, attr, value):
@@ -1128,3 +1200,68 @@ def unpack_latents(combined_latent, latent_shapes):
else:
output_tensors = combined_latent
return output_tensors
+
+def detect_layer_quantization(state_dict, prefix):
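+    # Version 1 quantization metadata is signalled by per-layer ".comfy_quant" marker keys under the given prefix.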
+ for k in state_dict:
+ if k.startswith(prefix) and k.endswith(".comfy_quant"):
+ logging.info("Found quantization metadata version 1")
+ return {"mixed_ops": True}
+ return None
+
+def convert_old_quants(state_dict, model_prefix="", metadata={}):
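+    # Normalizes quantized checkpoints: legacy "scaled_fp8" state dicts get their ".scale_weight"/".scale_input" keys renamed
+    # to ".weight_scale"/".input_scale", and per-layer ".comfy_quant" marker tensors holding the JSON-encoded layer config are added.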
+ if metadata is None:
+ metadata = {}
+
+ quant_metadata = None
+ if "_quantization_metadata" not in metadata:
+ scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
+
+ if scaled_fp8_key in state_dict:
+ scaled_fp8_weight = state_dict[scaled_fp8_key]
+ scaled_fp8_dtype = scaled_fp8_weight.dtype
+ if scaled_fp8_dtype == torch.float32:
+ scaled_fp8_dtype = torch.float8_e4m3fn
+
+ if scaled_fp8_weight.nelement() == 2:
+ full_precision_matrix_mult = True
+ else:
+ full_precision_matrix_mult = False
+
+ out_sd = {}
+ layers = {}
+ for k in list(state_dict.keys()):
+ if not k.startswith(model_prefix):
+ out_sd[k] = state_dict[k]
+ continue
+ k_out = k
+ w = state_dict.pop(k)
+ layer = None
+ if k_out.endswith(".scale_weight"):
+ layer = k_out[:-len(".scale_weight")]
+ k_out = "{}.weight_scale".format(layer)
+
+ if layer is not None:
+ layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+ if full_precision_matrix_mult:
+ layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+ layers[layer] = layer_conf
+
+ if k_out.endswith(".scale_input"):
+ layer = k_out[:-len(".scale_input")]
+ k_out = "{}.input_scale".format(layer)
+ if w.item() == 1.0:
+ continue
+
+ out_sd[k_out] = w
+
+ state_dict = out_sd
+ quant_metadata = {"layers": layers}
+ else:
+ quant_metadata = json.loads(metadata["_quantization_metadata"])
+
+ if quant_metadata is not None:
+ layers = quant_metadata["layers"]
+ for k, v in layers.items():
+ state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
+
+ return state_dict, metadata
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 176ae36e0..35e1ac853 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -5,11 +5,11 @@ from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
-from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
-from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
-from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
-from . import _io as io
-from . import _ui as ui
+from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
+from ._input_impl import VideoFromFile, VideoFromComponents
+from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
+from . import _io_public as io
+from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
@@ -80,7 +80,7 @@ class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
- This should be used to initialize any global resources neeeded by the extension.
+ This should be used to initialize any global resources needed by the extension.
"""
@abstractmethod
diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py
index 87c81d73a..e634a0311 100644
--- a/comfy_api/latest/_input/video_types.py
+++ b/comfy_api/latest/_input/video_types.py
@@ -4,7 +4,7 @@ from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
-from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
+from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index bde37f90a..ea35c6062 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -3,14 +3,14 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
-from comfy_api.latest._input import AudioInput, VideoInput
+from .._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import math
import torch
-from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
+from .._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
@@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput):
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
- with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
+ extra_kwargs = {}
+ if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
+ extra_kwargs["format"] = format.value
+ with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 79c0722a9..313a5af20 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -4,7 +4,8 @@ import copy
import inspect
from abc import ABC, abstractmethod
from collections import Counter
-from dataclasses import asdict, dataclass
+from collections.abc import Iterable
+from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
from typing_extensions import NotRequired, final
@@ -25,7 +26,7 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
-from comfy_api.latest._resources import Resources, ResourcesLocal
+from ._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL
@@ -150,6 +151,9 @@ class _IO_V3:
def __init__(self):
pass
+ def validate(self):
+ pass
+
@property
def io_type(self):
return self.Parent.io_type
@@ -182,6 +186,9 @@ class Input(_IO_V3):
def get_io_type(self):
return _StringIOType(self.io_type)
+ def get_all(self) -> list[Input]:
+ return [self]
+
class WidgetInput(Input):
'''
Base class for a V3 Input with widget.
@@ -561,6 +568,8 @@ class Conditioning(ComfyTypeIO):
'''Used by WAN Camera.'''
time_dim_concat: NotRequired[torch.Tensor]
'''Used by WAN Phantom Subject.'''
+ time_dim_replace: NotRequired[torch.Tensor]
+ '''Used by Kandinsky5 I2V.'''
CondList = list[tuple[torch.Tensor, PooledDict]]
Type = CondList
@@ -814,13 +823,61 @@ class MultiType:
else:
return super().as_dict()
+@comfytype(io_type="COMFY_MATCHTYPE_V3")
+class MatchType(ComfyTypeIO):
+ class Template:
+ def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
+ self.template_id = template_id
+ # account for syntactic sugar
+ if not isinstance(allowed_types, Iterable):
+ allowed_types = [allowed_types]
+ for t in allowed_types:
+ if not isinstance(t, type):
+ if not isinstance(t, _ComfyType):
+ raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
+ else:
+ if not issubclass(t, _ComfyType):
+ raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
+ self.allowed_types = allowed_types
+
+ def as_dict(self):
+ return {
+ "template_id": self.template_id,
+ "allowed_types": ",".join([t.io_type for t in self.allowed_types]),
+ }
+
+ class Input(Input):
+ def __init__(self, id: str, template: MatchType.Template,
+ display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+ super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+ self.template = template
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "template": self.template.as_dict(),
+ })
+
+ class Output(Output):
+ def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
+ is_output_list=False):
+ super().__init__(id, display_name, tooltip, is_output_list)
+ self.template = template
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "template": self.template.as_dict(),
+ })
+
class DynamicInput(Input, ABC):
'''
Abstract class for dynamic input registration.
'''
- @abstractmethod
def get_dynamic(self) -> list[Input]:
- ...
+ return []
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ pass
+
class DynamicOutput(Output, ABC):
'''
@@ -830,99 +887,223 @@ class DynamicOutput(Output, ABC):
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
- @abstractmethod
def get_dynamic(self) -> list[Output]:
- ...
+ return []
@comfytype(io_type="COMFY_AUTOGROW_V3")
-class AutogrowDynamic(ComfyTypeI):
- Type = list[Any]
- class Input(DynamicInput):
- def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
- display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
- super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
- self.template_input = template_input
- if min is not None:
- assert(min >= 1)
- if max is not None:
- assert(max >= 1)
+class Autogrow(ComfyTypeI):
+ Type = dict[str, Any]
+ _MaxNames = 100 # NOTE: max 100 names for sanity
+
+ class _AutogrowTemplate:
+ def __init__(self, input: Input):
+ # dynamic inputs are not allowed as the template input
+ assert(not isinstance(input, DynamicInput))
+ self.input = copy.copy(input)
+ if isinstance(self.input, WidgetInput):
+ self.input.force_input = True
+ self.names: list[str] = []
+ self.cached_inputs = {}
+
+ def _create_input(self, input: Input, name: str):
+ new_input = copy.copy(self.input)
+ new_input.id = name
+ return new_input
+
+ def _create_cached_inputs(self):
+ for name in self.names:
+ self.cached_inputs[name] = self._create_input(self.input, name)
+
+ def get_all(self) -> list[Input]:
+ return list(self.cached_inputs.values())
+
+ def as_dict(self):
+ return prune_dict({
+ "input": create_input_dict_v1([self.input]),
+ })
+
+ def validate(self):
+ self.input.validate()
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
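+            # Only template inputs whose names actually appear in live_inputs are added to the generated V1 schema.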
+ real_inputs = []
+ for name, input in self.cached_inputs.items():
+ if name in live_inputs:
+ real_inputs.append(input)
+ add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
+ add_dynamic_id_mapping(d, real_inputs, curr_prefix)
+
+ class TemplatePrefix(_AutogrowTemplate):
+ def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
+ super().__init__(input)
+ self.prefix = prefix
+ assert(min >= 0)
+ assert(max >= 1)
+ assert(max <= Autogrow._MaxNames)
self.min = min
self.max = max
+ self.names = [f"{self.prefix}{i}" for i in range(self.max)]
+ self._create_cached_inputs()
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "prefix": self.prefix,
+ "min": self.min,
+ "max": self.max,
+ })
+
+ class TemplateNames(_AutogrowTemplate):
+ def __init__(self, input: Input, names: list[str], min: int=1):
+ super().__init__(input)
+ self.names = names[:Autogrow._MaxNames]
+ assert(min >= 0)
+ self.min = min
+ self._create_cached_inputs()
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "names": self.names,
+ "min": self.min,
+ })
+
+ class Input(DynamicInput):
+ def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
+ display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+ super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+ self.template = template
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "template": self.template.as_dict(),
+ })
def get_dynamic(self) -> list[Input]:
- curr_count = 1
- new_inputs = []
- for i in range(self.min):
- new_input = copy.copy(self.template_input)
- new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
- if new_input.display_name is not None:
- new_input.display_name = f"{new_input.display_name}{curr_count}"
- new_input.optional = self.optional or new_input.optional
- if isinstance(self.template_input, WidgetInput):
- new_input.force_input = True
- new_inputs.append(new_input)
- curr_count += 1
- # pretend to expand up to max
- for i in range(curr_count-1, self.max):
- new_input = copy.copy(self.template_input)
- new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
- if new_input.display_name is not None:
- new_input.display_name = f"{new_input.display_name}{curr_count}"
- new_input.optional = True
- if isinstance(self.template_input, WidgetInput):
- new_input.force_input = True
- new_inputs.append(new_input)
- curr_count += 1
- return new_inputs
+ return self.template.get_all()
-@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
-class ComboDynamic(ComfyTypeI):
- class Input(DynamicInput):
- def __init__(self, id: str):
- pass
+ def get_all(self) -> list[Input]:
+ return [self] + self.template.get_all()
-@comfytype(io_type="COMFY_MATCHTYPE_V3")
-class MatchType(ComfyTypeIO):
- class Template:
- def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
- self.template_id = template_id
- self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
+ def validate(self):
+ self.template.validate()
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ curr_prefix = f"{curr_prefix}{self.id}."
+ # need to remove self from expected inputs dictionary; replaced by template inputs in frontend
+ for inner_dict in d.values():
+ if self.id in inner_dict:
+ del inner_dict[self.id]
+ self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+
+@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
+class DynamicCombo(ComfyTypeI):
+ Type = dict[str, Any]
+
+ class Option:
+ def __init__(self, key: str, inputs: list[Input]):
+ self.key = key
+ self.inputs = inputs
def as_dict(self):
return {
- "template_id": self.template_id,
- "allowed_types": "".join(t.io_type for t in self.allowed_types),
+ "key": self.key,
+ "inputs": create_input_dict_v1(self.inputs),
}
class Input(DynamicInput):
- def __init__(self, id: str, template: MatchType.Template,
+ def __init__(self, id: str, options: list[DynamicCombo.Option],
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
- self.template = template
+ self.options = options
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ # check if dynamic input's id is in live_inputs
+ if self.id in live_inputs:
+ curr_prefix = f"{curr_prefix}{self.id}."
+ key = live_inputs[self.id]
+ selected_option = None
+ for option in self.options:
+ if option.key == key:
+ selected_option = option
+ break
+ if selected_option is not None:
+ add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
+ add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
def get_dynamic(self) -> list[Input]:
- return [self]
+ return [input for option in self.options for input in option.inputs]
+
+ def get_all(self) -> list[Input]:
+ return [self] + [input for option in self.options for input in option.inputs]
def as_dict(self):
return super().as_dict() | prune_dict({
- "template": self.template.as_dict(),
+ "options": [o.as_dict() for o in self.options],
})
- class Output(DynamicOutput):
- def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
- is_output_list=False):
- super().__init__(id, display_name, tooltip, is_output_list)
- self.template = template
+ def validate(self):
+ # make sure all nested inputs are validated
+ for option in self.options:
+ for input in option.inputs:
+ input.validate()
- def get_dynamic(self) -> list[Output]:
- return [self]
+@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
+class DynamicSlot(ComfyTypeI):
+ Type = dict[str, Any]
+
+ class Input(DynamicInput):
+ def __init__(self, slot: Input, inputs: list[Input],
+ display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
+ assert(not isinstance(slot, DynamicInput))
+ self.slot = copy.copy(slot)
+ self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
+ optional = True
+ self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
+ self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
+ self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
+ super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
+ self.inputs = inputs
+ self.force_input = None
+ # force widget inputs to have no widgets, otherwise this would be awkward
+ if isinstance(self.slot, WidgetInput):
+ self.force_input = True
+ self.slot.force_input = True
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ if self.id in live_inputs:
+ curr_prefix = f"{curr_prefix}{self.id}."
+ add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
+ add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
+
+ def get_dynamic(self) -> list[Input]:
+ return [self.slot] + self.inputs
+
+ def get_all(self) -> list[Input]:
+ return [self] + [self.slot] + self.inputs
def as_dict(self):
return super().as_dict() | prune_dict({
- "template": self.template.as_dict(),
+ "slotType": str(self.slot.get_io_type()),
+ "inputs": create_input_dict_v1(self.inputs),
+ "forceInput": self.force_input,
})
+ def validate(self):
+ self.slot.validate()
+ for input in self.inputs:
+ input.validate()
+
+def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
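+    # Records a flat input id -> dotted path mapping under "dynamic_paths" so nested values can be rebuilt at execution time.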
+ dynamic = d.setdefault("dynamic_paths", {})
+ if self is not None:
+ dynamic[self.id] = f"{curr_prefix}{self.id}"
+ for i in inputs:
+ if not isinstance(i, DynamicInput):
+ dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
+
+class V3Data(TypedDict):
+ hidden_inputs: dict[str, Any]
+ dynamic_paths: dict[str, Any]
class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any,
@@ -984,6 +1165,7 @@ class NodeInfoV1:
output_is_list: list[bool]=None
output_name: list[str]=None
output_tooltips: list[str]=None
+ output_matchtypes: list[str]=None
name: str=None
display_name: str=None
description: str=None
@@ -1019,9 +1201,9 @@ class Schema:
"""Display name of node."""
category: str = "sd"
"""The category of the node, as per the "Add Node" menu."""
- inputs: list[Input]=None
- outputs: list[Output]=None
- hidden: list[Hidden]=None
+ inputs: list[Input] = field(default_factory=list)
+ outputs: list[Output] = field(default_factory=list)
+ hidden: list[Hidden] = field(default_factory=list)
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
is_input_list: bool = False
@@ -1061,7 +1243,11 @@ class Schema:
'''Validate the schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other
'''
- input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
+ nested_inputs: list[Input] = []
+ if self.inputs is not None:
+ for input in self.inputs:
+ nested_inputs.extend(input.get_all())
+ input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
input_set = set(input_ids)
output_set = set(output_ids)
@@ -1077,6 +1263,13 @@ class Schema:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0:
raise ValueError("\n".join(issues))
+ # validate inputs and outputs
+ if self.inputs is not None:
+ for input in self.inputs:
+ input.validate()
+ if self.outputs is not None:
+ for output in self.outputs:
+ output.validate()
def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
@@ -1102,19 +1295,10 @@ class Schema:
if output.id is None:
output.id = f"_{i}_{output.io_type}_"
- def get_v1_info(self, cls) -> NodeInfoV1:
+ def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
+        # NOTE: live_inputs will be removed soon; this will be handled a different way
# get V1 inputs
- input = {
- "required": {}
- }
- if self.inputs:
- for i in self.inputs:
- if isinstance(i, DynamicInput):
- dynamic_inputs = i.get_dynamic()
- for d in dynamic_inputs:
- add_to_dict_v1(d, input)
- else:
- add_to_dict_v1(i, input)
+ input = create_input_dict_v1(self.inputs, live_inputs)
if self.hidden:
for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@@ -1123,12 +1307,24 @@ class Schema:
output_is_list = []
output_name = []
output_tooltips = []
+ output_matchtypes = []
+ any_matchtypes = False
if self.outputs:
for o in self.outputs:
output.append(o.io_type)
output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None)
+ # special handling for MatchType
+ if isinstance(o, MatchType.Output):
+ output_matchtypes.append(o.template.template_id)
+ any_matchtypes = True
+ else:
+ output_matchtypes.append(None)
+
+ # clear out lists that are all None
+ if not any_matchtypes:
+ output_matchtypes = None
info = NodeInfoV1(
input=input,
@@ -1137,6 +1333,7 @@ class Schema:
output_is_list=output_is_list,
output_name=output_name,
output_tooltips=output_tooltips,
+ output_matchtypes=output_matchtypes,
name=self.node_id,
display_name=self.display_name,
category=self.category,
@@ -1182,16 +1379,57 @@ class Schema:
return info
-def add_to_dict_v1(i: Input, input: dict):
+def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
+ input = {
+ "required": {}
+ }
+ add_to_input_dict_v1(input, inputs, live_inputs)
+ return input
+
+def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
+ for i in inputs:
+ if isinstance(i, DynamicInput):
+ add_to_dict_v1(i, d)
+ if live_inputs is not None:
+ i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+ else:
+ add_to_dict_v1(i, d)
+
+def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
key = "optional" if i.optional else "required"
as_dict = i.as_dict()
# for v1, we don't want to include the optional key
as_dict.pop("optional", None)
- input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
+ if dynamic_dict is None:
+ value = (i.get_io_type(), as_dict)
+ else:
+ value = (i.get_io_type(), as_dict, dynamic_dict)
+ d.setdefault(key, {})[i.id] = value
def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict())
+def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
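+    # Rebuilds nested input dictionaries from the flat execution values using the dotted paths recorded in "dynamic_paths".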
+ paths = v3_data.get("dynamic_paths", None)
+ if paths is None:
+ return values
+ values = values.copy()
+ result = {}
+
+ for key, path in paths.items():
+ parts = path.split(".")
+ current = result
+
+ for i, p in enumerate(parts):
+ is_last = (i == len(parts) - 1)
+
+ if is_last:
+ current[p] = values.pop(key, None)
+ else:
+ current = current.setdefault(p, {})
+
+ values.update(result)
+ return values
class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@@ -1311,12 +1549,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
- def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
+ def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden
- type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
+ type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
return type_clone
@final
@@ -1433,14 +1671,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
- def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
+ def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
schema = cls.FINALIZE_SCHEMA()
- info = schema.get_v1_info(cls)
+ info = schema.get_v1_info(cls, live_inputs)
input = info.input
if not include_hidden:
input.pop("hidden", None)
if return_schema:
- return input, schema
+ v3_data: V3Data = {}
+ dynamic = input.pop("dynamic_paths", None)
+ if dynamic is not None:
+ v3_data["dynamic_paths"] = dynamic
+ return input, schema, v3_data
return input
@final
@@ -1513,7 +1755,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
raise NotImplementedError
@classmethod
- def validate_inputs(cls, **kwargs) -> bool:
+ def validate_inputs(cls, **kwargs) -> bool | str:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
raise NotImplementedError
@@ -1628,6 +1870,7 @@ __all__ = [
"StyleModel",
"Gligen",
"UpscaleModel",
+ "LatentUpscaleModel",
"Audio",
"Video",
"SVG",
@@ -1651,6 +1894,10 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
+ # Dynamic Types
+ "MatchType",
+ # "DynamicCombo",
+ # "Autogrow",
# Other classes
"HiddenHolder",
"Hidden",
@@ -1661,4 +1908,5 @@ __all__ = [
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
+ "V3Data",
]
diff --git a/comfy_api/latest/_io_public.py b/comfy_api/latest/_io_public.py
new file mode 100644
index 000000000..43c7680f3
--- /dev/null
+++ b/comfy_api/latest/_io_public.py
@@ -0,0 +1 @@
+from ._io import * # noqa: F403
diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py
index b0bbabe2a..2babe209a 100644
--- a/comfy_api/latest/_ui.py
+++ b/comfy_api/latest/_ui.py
@@ -3,6 +3,7 @@ from __future__ import annotations
import json
import os
import random
+import uuid
from io import BytesIO
from typing import Type
@@ -21,7 +22,7 @@ import folder_paths
# used for image preview
from comfy.cli_args import args
-from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
+from ._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):
@@ -318,9 +319,10 @@ class AudioSaveHelper:
for key, value in metadata.items():
output_container.metadata[key] = value
+ layout = "mono" if waveform.shape[0] == 1 else "stereo"
# Set up the output stream with appropriate properties
if format == "opus":
- out_stream = output_container.add_stream("libopus", rate=sample_rate)
+ out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
@@ -332,7 +334,7 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
- out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
+ out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
@@ -341,12 +343,12 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
- out_stream = output_container.add_stream("flac", rate=sample_rate)
+ out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
- layout="mono" if waveform.shape[0] == 1 else "stereo",
+ layout=layout,
)
frame.sample_rate = sample_rate
frame.pts = 0
@@ -436,9 +438,19 @@ class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file
self.camera_info = camera_info
+ self.bg_image_path = None
+ bg_image = kwargs.get("bg_image", None)
+ if bg_image is not None:
+ img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
+ img = PILImage.fromarray(img_array)
+ temp_dir = folder_paths.get_temp_directory()
+ filename = f"bg_{uuid.uuid4().hex}.png"
+ bg_image_path = os.path.join(temp_dir, filename)
+ img.save(bg_image_path, compress_level=1)
+ self.bg_image_path = f"temp/{filename}"
def as_dict(self):
- return {"result": [self.model_file, self.camera_info]}
+ return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
class PreviewText(_UIOutput):
diff --git a/comfy_api/latest/_ui_public.py b/comfy_api/latest/_ui_public.py
new file mode 100644
index 000000000..85b11d78b
--- /dev/null
+++ b/comfy_api/latest/_ui_public.py
@@ -0,0 +1 @@
+from ._ui import * # noqa: F403
diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py
index c3e3d8e3a..fd3b5a510 100644
--- a/comfy_api/latest/_util/video_types.py
+++ b/comfy_api/latest/_util/video_types.py
@@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
-from comfy_api.latest._input import ImageInput, AudioInput
+from .._input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"
diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py
index de0f95001..c4fa1d971 100644
--- a/comfy_api/v0_0_2/__init__.py
+++ b/comfy_api/v0_0_2/__init__.py
@@ -6,7 +6,7 @@ from comfy_api.latest import (
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
-from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
+from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@@ -42,4 +42,8 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
+ "io",
+ "IO",
+ "ui",
+ "UI",
]
diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py
new file mode 100644
index 000000000..77cd76f9b
--- /dev/null
+++ b/comfy_api_nodes/apis/bytedance_api.py
@@ -0,0 +1,144 @@
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class Text2ImageTaskCreationRequest(BaseModel):
+ model: str = Field(...)
+ prompt: str = Field(...)
+ response_format: str | None = Field("url")
+ size: str | None = Field(None)
+ seed: int | None = Field(0, ge=0, le=2147483647)
+ guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
+ watermark: bool | None = Field(True)
+
+
+class Image2ImageTaskCreationRequest(BaseModel):
+ model: str = Field(...)
+ prompt: str = Field(...)
+ response_format: str | None = Field("url")
+ image: str = Field(..., description="Base64 encoded string or image URL")
+ size: str | None = Field("adaptive")
+ seed: int | None = Field(..., ge=0, le=2147483647)
+ guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
+ watermark: bool | None = Field(True)
+
+
+class Seedream4Options(BaseModel):
+ max_images: int = Field(15)
+
+
+class Seedream4TaskCreationRequest(BaseModel):
+ model: str = Field(...)
+ prompt: str = Field(...)
+ response_format: str = Field("url")
+ image: list[str] | None = Field(None, description="Image URLs")
+ size: str = Field(...)
+ seed: int = Field(..., ge=0, le=2147483647)
+ sequential_image_generation: str = Field("disabled")
+ sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
+ watermark: bool = Field(True)
+
+
+class ImageTaskCreationResponse(BaseModel):
+ model: str = Field(...)
+ created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
+ data: list = Field([], description="Contains information about the generated image(s).")
+ error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
+
+
+class TaskTextContent(BaseModel):
+ type: str = Field("text")
+ text: str = Field(...)
+
+
+class TaskImageContentUrl(BaseModel):
+ url: str = Field(...)
+
+
+class TaskImageContent(BaseModel):
+ type: str = Field("image_url")
+ image_url: TaskImageContentUrl = Field(...)
+ role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
+
+
+class Text2VideoTaskCreationRequest(BaseModel):
+ model: str = Field(...)
+ content: list[TaskTextContent] = Field(..., min_length=1)
+
+
+class Image2VideoTaskCreationRequest(BaseModel):
+ model: str = Field(...)
+ content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
+
+
+class TaskCreationResponse(BaseModel):
+ id: str = Field(...)
+
+
+class TaskStatusError(BaseModel):
+ code: str = Field(...)
+ message: str = Field(...)
+
+
+class TaskStatusResult(BaseModel):
+ video_url: str = Field(...)
+
+
+class TaskStatusResponse(BaseModel):
+ id: str = Field(...)
+ model: str = Field(...)
+ status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
+ error: TaskStatusError | None = Field(None)
+ content: TaskStatusResult | None = Field(None)
+
+
+RECOMMENDED_PRESETS = [
+ ("1024x1024 (1:1)", 1024, 1024),
+ ("864x1152 (3:4)", 864, 1152),
+ ("1152x864 (4:3)", 1152, 864),
+ ("1280x720 (16:9)", 1280, 720),
+ ("720x1280 (9:16)", 720, 1280),
+ ("832x1248 (2:3)", 832, 1248),
+ ("1248x832 (3:2)", 1248, 832),
+ ("1512x648 (21:9)", 1512, 648),
+ ("2048x2048 (1:1)", 2048, 2048),
+ ("Custom", None, None),
+]
+
+RECOMMENDED_PRESETS_SEEDREAM_4 = [
+ ("2048x2048 (1:1)", 2048, 2048),
+ ("2304x1728 (4:3)", 2304, 1728),
+ ("1728x2304 (3:4)", 1728, 2304),
+ ("2560x1440 (16:9)", 2560, 1440),
+ ("1440x2560 (9:16)", 1440, 2560),
+ ("2496x1664 (3:2)", 2496, 1664),
+ ("1664x2496 (2:3)", 1664, 2496),
+ ("3024x1296 (21:9)", 3024, 1296),
+ ("4096x4096 (1:1)", 4096, 4096),
+ ("Custom", None, None),
+]
+
+# The times in this dictionary are given for a 10 second video duration.
+VIDEO_TASKS_EXECUTION_TIME = {
+ "seedance-1-0-lite-t2v-250428": {
+ "480p": 40,
+ "720p": 60,
+ "1080p": 90,
+ },
+ "seedance-1-0-lite-i2v-250428": {
+ "480p": 40,
+ "720p": 60,
+ "1080p": 90,
+ },
+ "seedance-1-0-pro-250528": {
+ "480p": 70,
+ "720p": 85,
+ "1080p": 115,
+ },
+ "seedance-1-0-pro-fast-251015": {
+ "480p": 50,
+ "720p": 65,
+ "1080p": 100,
+ },
+}
diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py
index d34590d28..f8edc38c9 100644
--- a/comfy_api_nodes/apis/gemini_api.py
+++ b/comfy_api_nodes/apis/gemini_api.py
@@ -58,8 +58,14 @@ class GeminiInlineData(BaseModel):
mimeType: GeminiMimeType | None = Field(None)
+class GeminiFileData(BaseModel):
+ fileUri: str | None = Field(None)
+ mimeType: GeminiMimeType | None = Field(None)
+
+
class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
+ fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None)
@@ -78,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
- role: GeminiRole = Field(
- ...,
- description="The identity of the entity that creates the message. "
- "The following values are supported: "
- "user: This indicates that the message is sent by a real person, typically a user-generated message. "
- "model: This indicates that the message is generated by the model. "
- "The model value is used to insert messages from model into the conversation during multi-turn conversations. "
- "For non-multi-turn conversations, this field can be left blank or unset.",
- )
+ role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
class GeminiFunctionDeclaration(BaseModel):
diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py
new file mode 100644
index 000000000..d8949f8ac
--- /dev/null
+++ b/comfy_api_nodes/apis/kling_api.py
@@ -0,0 +1,86 @@
+from pydantic import BaseModel, Field
+
+
+class OmniProText2VideoRequest(BaseModel):
+ model_name: str = Field(..., description="kling-video-o1")
+ aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
+ duration: str = Field(..., description="'5' or '10'")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+
+
+class OmniParamImage(BaseModel):
+ image_url: str = Field(...)
+ type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
+
+
+class OmniParamVideo(BaseModel):
+ video_url: str = Field(...)
+ refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
+ keep_original_sound: str = Field(..., description="'yes' or 'no'")
+
+
+class OmniProFirstLastFrameRequest(BaseModel):
+ model_name: str = Field(..., description="kling-video-o1")
+ image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
+ duration: str = Field(..., description="'5' or '10'")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+
+
+class OmniProReferences2VideoRequest(BaseModel):
+ model_name: str = Field(..., description="kling-video-o1")
+ aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
+ image_list: list[OmniParamImage] | None = Field(
+ None, max_length=7, description="Max length 4 when video is present."
+ )
+ video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
+ duration: str | None = Field(..., description="From 3 to 10.")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+
+
+class TaskStatusVideoResult(BaseModel):
+ duration: str | None = Field(None, description="Total video duration")
+ id: str | None = Field(None, description="Generated video ID")
+ url: str | None = Field(None, description="URL for generated video")
+
+
+class TaskStatusImageResult(BaseModel):
+    index: int = Field(..., description="Image number, 0-9")
+ url: str = Field(..., description="URL for generated image")
+
+
+class OmniTaskStatusResults(BaseModel):
+ videos: list[TaskStatusVideoResult] | None = Field(None)
+ images: list[TaskStatusImageResult] | None = Field(None)
+
+
+class OmniTaskStatusResponseData(BaseModel):
+ created_at: int | None = Field(None, description="Task creation time")
+ updated_at: int | None = Field(None, description="Task update time")
+ task_status: str | None = None
+ task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
+ task_id: str | None = Field(None, description="Task ID")
+ task_result: OmniTaskStatusResults | None = Field(None)
+
+
+class OmniTaskStatusResponse(BaseModel):
+ code: int | None = Field(None, description="Error code")
+ message: str | None = Field(None, description="Error message")
+ request_id: str | None = Field(None, description="Request ID")
+ data: OmniTaskStatusResponseData | None = Field(None)
+
+
+class OmniImageParamImage(BaseModel):
+ image: str = Field(...)
+
+
+class OmniProImageRequest(BaseModel):
+ model_name: str = Field(..., description="kling-image-o1")
+ resolution: str = Field(..., description="'1k' or '2k'")
+ aspect_ratio: str | None = Field(...)
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+ n: int | None = Field(1, le=9)
+ image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py
index a55137afb..23ca725b7 100644
--- a/comfy_api_nodes/apis/veo_api.py
+++ b/comfy_api_nodes/apis/veo_api.py
@@ -1,34 +1,21 @@
-from typing import Optional, Union
-from enum import Enum
+from typing import Optional
from pydantic import BaseModel, Field
-class Image2(BaseModel):
- bytesBase64Encoded: str
- gcsUri: Optional[str] = None
- mimeType: Optional[str] = None
+class VeoRequestInstanceImage(BaseModel):
+ bytesBase64Encoded: str | None = Field(None)
+ gcsUri: str | None = Field(None)
+ mimeType: str | None = Field(None)
-class Image3(BaseModel):
- bytesBase64Encoded: Optional[str] = None
- gcsUri: str
- mimeType: Optional[str] = None
-
-
-class Instance1(BaseModel):
- image: Optional[Union[Image2, Image3]] = Field(
- None, description='Optional image to guide video generation'
- )
+class VeoRequestInstance(BaseModel):
+ image: VeoRequestInstanceImage | None = Field(None)
+ lastFrame: VeoRequestInstanceImage | None = Field(None)
prompt: str = Field(..., description='Text description of the video')
-class PersonGeneration1(str, Enum):
- ALLOW = 'ALLOW'
- BLOCK = 'BLOCK'
-
-
-class Parameters1(BaseModel):
+class VeoRequestParameters(BaseModel):
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
durationSeconds: Optional[int] = None
enhancePrompt: Optional[bool] = None
@@ -37,17 +24,18 @@ class Parameters1(BaseModel):
description='Generate audio for the video. Only supported by veo 3 models.',
)
negativePrompt: Optional[str] = None
- personGeneration: Optional[PersonGeneration1] = None
+ personGeneration: str | None = Field(None, description="ALLOW or BLOCK")
sampleCount: Optional[int] = None
seed: Optional[int] = None
storageUri: Optional[str] = Field(
None, description='Optional Cloud Storage URI to upload the video'
)
+ resolution: str | None = Field(None)
class VeoGenVidRequest(BaseModel):
- instances: Optional[list[Instance1]] = None
- parameters: Optional[Parameters1] = None
+ instances: list[VeoRequestInstance] | None = Field(None)
+ parameters: VeoRequestParameters | None = Field(None)
class VeoGenVidResponse(BaseModel):
@@ -97,7 +85,7 @@ class Response1(BaseModel):
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
- videos: Optional[list[Video]] = None
+ videos: Optional[list[Video]] = Field(None)
class VeoGenVidPollResponse(BaseModel):
diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py
index caced471e..57c0218d0 100644
--- a/comfy_api_nodes/nodes_bytedance.py
+++ b/comfy_api_nodes/nodes_bytedance.py
@@ -1,13 +1,27 @@
import logging
import math
-from enum import Enum
-from typing import Literal, Optional, Union
import torch
-from pydantic import BaseModel, Field
from typing_extensions import override
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.bytedance_api import (
+ RECOMMENDED_PRESETS,
+ RECOMMENDED_PRESETS_SEEDREAM_4,
+ VIDEO_TASKS_EXECUTION_TIME,
+ Image2ImageTaskCreationRequest,
+ Image2VideoTaskCreationRequest,
+ ImageTaskCreationResponse,
+ Seedream4Options,
+ Seedream4TaskCreationRequest,
+ TaskCreationResponse,
+ TaskImageContent,
+ TaskImageContentUrl,
+ TaskStatusResponse,
+ TaskTextContent,
+ Text2ImageTaskCreationRequest,
+ Text2VideoTaskCreationRequest,
+)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
@@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
-class Text2ImageModelName(str, Enum):
- seedream_3 = "seedream-3-0-t2i-250415"
-
-
-class Image2ImageModelName(str, Enum):
- seededit_3 = "seededit-3-0-i2i-250628"
-
-
-class Text2VideoModelName(str, Enum):
- seedance_1_pro = "seedance-1-0-pro-250528"
- seedance_1_lite = "seedance-1-0-lite-t2v-250428"
-
-
-class Image2VideoModelName(str, Enum):
- """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
-
- seedance_1_pro = "seedance-1-0-pro-250528"
- seedance_1_lite = "seedance-1-0-lite-i2v-250428"
-
-
-class Text2ImageTaskCreationRequest(BaseModel):
- model: Text2ImageModelName = Text2ImageModelName.seedream_3
- prompt: str = Field(...)
- response_format: Optional[str] = Field("url")
- size: Optional[str] = Field(None)
- seed: Optional[int] = Field(0, ge=0, le=2147483647)
- guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
- watermark: Optional[bool] = Field(True)
-
-
-class Image2ImageTaskCreationRequest(BaseModel):
- model: Image2ImageModelName = Image2ImageModelName.seededit_3
- prompt: str = Field(...)
- response_format: Optional[str] = Field("url")
- image: str = Field(..., description="Base64 encoded string or image URL")
- size: Optional[str] = Field("adaptive")
- seed: Optional[int] = Field(..., ge=0, le=2147483647)
- guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
- watermark: Optional[bool] = Field(True)
-
-
-class Seedream4Options(BaseModel):
- max_images: int = Field(15)
-
-
-class Seedream4TaskCreationRequest(BaseModel):
- model: str = Field("seedream-4-0-250828")
- prompt: str = Field(...)
- response_format: str = Field("url")
- image: Optional[list[str]] = Field(None, description="Image URLs")
- size: str = Field(...)
- seed: int = Field(..., ge=0, le=2147483647)
- sequential_image_generation: str = Field("disabled")
- sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
- watermark: bool = Field(True)
-
-
-class ImageTaskCreationResponse(BaseModel):
- model: str = Field(...)
- created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
- data: list = Field([], description="Contains information about the generated image(s).")
- error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
-
-
-class TaskTextContent(BaseModel):
- type: str = Field("text")
- text: str = Field(...)
-
-
-class TaskImageContentUrl(BaseModel):
- url: str = Field(...)
-
-
-class TaskImageContent(BaseModel):
- type: str = Field("image_url")
- image_url: TaskImageContentUrl = Field(...)
- role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
-
-
-class Text2VideoTaskCreationRequest(BaseModel):
- model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
- content: list[TaskTextContent] = Field(..., min_length=1)
-
-
-class Image2VideoTaskCreationRequest(BaseModel):
- model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
- content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
-
-
-class TaskCreationResponse(BaseModel):
- id: str = Field(...)
-
-
-class TaskStatusError(BaseModel):
- code: str = Field(...)
- message: str = Field(...)
-
-
-class TaskStatusResult(BaseModel):
- video_url: str = Field(...)
-
-
-class TaskStatusResponse(BaseModel):
- id: str = Field(...)
- model: str = Field(...)
- status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
- error: Optional[TaskStatusError] = Field(None)
- content: Optional[TaskStatusResult] = Field(None)
-
-
-RECOMMENDED_PRESETS = [
- ("1024x1024 (1:1)", 1024, 1024),
- ("864x1152 (3:4)", 864, 1152),
- ("1152x864 (4:3)", 1152, 864),
- ("1280x720 (16:9)", 1280, 720),
- ("720x1280 (9:16)", 720, 1280),
- ("832x1248 (2:3)", 832, 1248),
- ("1248x832 (3:2)", 1248, 832),
- ("1512x648 (21:9)", 1512, 648),
- ("2048x2048 (1:1)", 2048, 2048),
- ("Custom", None, None),
-]
-
-RECOMMENDED_PRESETS_SEEDREAM_4 = [
- ("2048x2048 (1:1)", 2048, 2048),
- ("2304x1728 (4:3)", 2304, 1728),
- ("1728x2304 (3:4)", 1728, 2304),
- ("2560x1440 (16:9)", 2560, 1440),
- ("1440x2560 (9:16)", 1440, 2560),
- ("2496x1664 (3:2)", 2496, 1664),
- ("1664x2496 (2:3)", 1664, 2496),
- ("3024x1296 (21:9)", 3024, 1296),
- ("4096x4096 (1:1)", 4096, 4096),
- ("Custom", None, None),
-]
-
-# The time in this dictionary are given for 10 seconds duration.
-VIDEO_TASKS_EXECUTION_TIME = {
- "seedance-1-0-lite-t2v-250428": {
- "480p": 40,
- "720p": 60,
- "1080p": 90,
- },
- "seedance-1-0-lite-i2v-250428": {
- "480p": 40,
- "720p": 60,
- "1080p": 90,
- },
- "seedance-1-0-pro-250528": {
- "480p": 70,
- "720p": 85,
- "1080p": 115,
- },
-}
-
-
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
@@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
return response.data[0]["url"]
-def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
- """Returns the video URL from the task status response if it exists."""
- if hasattr(response, "content") and response.content:
- return response.content.video_url
- return None
-
-
class ByteDanceImageNode(IO.ComfyNode):
@classmethod
@@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
- IO.Combo.Input(
- "model",
- options=Text2ImageModelName,
- default=Text2ImageModelName.seedream_3,
- tooltip="Model name",
- ),
+ IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Edit images using ByteDance models via api based on prompt",
inputs=[
- IO.Combo.Input(
- "model",
- options=Image2ImageModelName,
- default=Image2ImageModelName.seededit_3,
- tooltip="Model name",
- ),
+ IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
IO.Image.Input(
"image",
tooltip="The base image to edit",
@@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
async def execute(
cls,
model: str,
- image: torch.Tensor,
+ image: Input.Image,
prompt: str,
seed: int,
guidance_scale: float,
@@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
- options=["seedream-4-0-250828"],
+ options=["seedream-4-5-251128", "seedream-4-0-250828"],
tooltip="Model name",
),
IO.String.Input(
@@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
- step=64,
+ step=8,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
- step=64,
+ step=8,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
- image: torch.Tensor = None,
+ image: Input.Image | None = None,
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
width: int = 2048,
height: int = 2048,
@@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
raise ValueError(
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
)
+ out_num_pixels = w * h
+ mp_provided = out_num_pixels / 1_000_000.0
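+            # 3,686,400 px = 1920x1920 (Seedream 4.5 minimum); 921,600 px = 1280x720 (Seedream 4.0 minimum).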
+ if "seedream-4-5" in model and out_num_pixels < 3686400:
+ raise ValueError(
+ f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
+ f"but {mp_provided:.2f}MP provided."
+ )
+ if "seedream-4-0" in model and out_num_pixels < 921600:
+ raise ValueError(
+ f"Minimum image resolution that the selected model can generate is 0.92MP, "
+ f"but {mp_provided:.2f}MP provided."
+ )
n_input_images = get_number_of_images(image) if image is not None else 0
if n_input_images > 10:
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
@@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
- options=Text2VideoModelName,
- default=Text2VideoModelName.seedance_1_pro,
- tooltip="Model name",
+ options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
+ default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
- options=Image2VideoModelName,
- default=Image2VideoModelName.seedance_1_pro,
- tooltip="Model name",
+ options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
+ default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
- image: torch.Tensor,
+ image: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
- options=[model.value for model in Image2VideoModelName],
- default=Image2VideoModelName.seedance_1_lite.value,
- tooltip="Model name",
+ options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
+ default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
- first_frame: torch.Tensor,
- last_frame: torch.Tensor,
+ first_frame: Input.Image,
+ last_frame: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
- options=[Image2VideoModelName.seedance_1_lite.value],
- default=Image2VideoModelName.seedance_1_lite.value,
- tooltip="Model name",
+ options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
+ default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
- images: torch.Tensor,
+ images: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
async def process_video_task(
cls: type[IO.ComfyNode],
- payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
- estimated_duration: Optional[int],
+ payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
+ estimated_duration: int | None,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
@@ -1085,7 +934,7 @@ async def process_video_task(
estimated_duration=estimated_duration,
response_model=TaskStatusResponse,
)
- return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
+ return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py
index 938a20f84..ad0f4b4d1 100644
--- a/comfy_api_nodes/nodes_gemini.py
+++ b/comfy_api_nodes/nodes_gemini.py
@@ -4,10 +4,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
"""
import base64
-import json
import os
-import time
-import uuid
from enum import Enum
from io import BytesIO
from typing import Literal
@@ -16,10 +13,10 @@ import torch
from typing_extensions import override
import folder_paths
-from comfy_api.latest import IO, ComfyExtension, Input
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.gemini_api import (
GeminiContent,
+ GeminiFileData,
GeminiGenerateContentRequest,
GeminiGenerateContentResponse,
GeminiImageConfig,
@@ -29,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import (
GeminiMimeType,
GeminiPart,
GeminiRole,
+ GeminiSystemInstructionContent,
+ GeminiTextPart,
Modality,
)
from comfy_api_nodes.util import (
@@ -38,13 +37,21 @@ from comfy_api_nodes.util import (
get_number_of_images,
sync_op,
tensor_to_base64_string,
+ upload_images_to_comfyapi,
validate_string,
video_to_base64_string,
)
-from server import PromptServer
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
+GEMINI_IMAGE_SYS_PROMPT = (
+ "You are an expert image-generation engine. You must ALWAYS produce an image.\n"
+ "Interpret all user input—regardless of "
+ "format, intent, or abstraction—as literal visual directives for image composition.\n"
+ "If a prompt is conversational or lacks specific visual details, "
+ "you must creatively invent a concrete visual scenario that depicts the concept.\n"
+ "Prioritize generating the visual representation above any text, formatting, or conversational requests."
+)
class GeminiModel(str, Enum):
@@ -68,24 +75,43 @@ class GeminiImageModel(str, Enum):
gemini_2_5_flash_image = "gemini-2.5-flash-image"
-def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
- """
- Convert image tensor input to Gemini API compatible parts.
-
- Args:
- image_input: Batch of image tensors from ComfyUI.
-
- Returns:
- List of GeminiPart objects containing the encoded images.
- """
+async def create_image_parts(
+ cls: type[IO.ComfyNode],
+ images: Input.Image,
+ image_limit: int = 0,
+) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
- for image_index in range(image_input.shape[0]):
- image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
+ if image_limit < 0:
+ raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
+ total_images = get_number_of_images(images)
+ if total_images <= 0:
+ raise ValueError("No images provided to create_image_parts; at least one image is required.")
+
+ # If image_limit == 0 --> use all images; otherwise clamp to image_limit.
+ effective_max = total_images if image_limit == 0 else min(total_images, image_limit)
+
+ # Number of images we'll send as URLs (fileData)
+ num_url_images = min(effective_max, 10) # Vertex API max number of image links
+ reference_images_urls = await upload_images_to_comfyapi(
+ cls,
+ images,
+ max_images=num_url_images,
+ )
+ for reference_image_url in reference_images_urls:
+ image_parts.append(
+ GeminiPart(
+ fileData=GeminiFileData(
+ mimeType=GeminiMimeType.image_png,
+ fileUri=reference_image_url,
+ )
+ )
+ )
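+    # Any images beyond the URL limit are sent as inline base64 parts instead.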
+ for idx in range(num_url_images, effective_max):
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
- data=image_as_b64,
+ data=tensor_to_base64_string(images[idx]),
)
)
)
@@ -137,8 +163,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
-def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
- image_tensors: list[torch.Tensor] = []
+def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
+ image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
@@ -260,6 +286,13 @@ class GeminiNode(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
+ IO.String.Input(
+ "system_prompt",
+ multiline=True,
+ default="",
+ optional=True,
+ tooltip="Foundational instructions that dictate an AI's behavior.",
+ ),
],
outputs=[
IO.String.Output(),
@@ -276,7 +309,9 @@ class GeminiNode(IO.ComfyNode):
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
- base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
+ base_64_string = video_to_base64_string(
+ video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
+ )
return [
GeminiPart(
inlineData=GeminiInlineData(
@@ -326,10 +361,11 @@ class GeminiNode(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
- images: torch.Tensor | None = None,
+ images: Input.Image | None = None,
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
+ system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@@ -338,8 +374,7 @@ class GeminiNode(IO.ComfyNode):
# Add other modal parts
if images is not None:
- image_parts = create_image_parts(images)
- parts.extend(image_parts)
+ parts.extend(await create_image_parts(cls, images))
if audio is not None:
parts.extend(cls.create_audio_parts(audio))
if video is not None:
@@ -347,7 +382,10 @@ class GeminiNode(IO.ComfyNode):
if files is not None:
parts.extend(files)
- # Create response
+ gemini_system_prompt = None
+ if system_prompt:
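+        # systemInstruction's role is documented as ignorable, so it is passed explicitly as None.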
+ gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
+
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@@ -357,36 +395,14 @@ class GeminiNode(IO.ComfyNode):
role=GeminiRole.user,
parts=parts,
)
- ]
+ ],
+ systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
output_text = get_text_from_response(response)
- if output_text:
- # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
- render_spec = {
- "node_id": cls.hidden.unique_id,
- "component": "ChatHistoryWidget",
- "props": {
- "history": json.dumps(
- [
- {
- "prompt": prompt,
- "response": output_text,
- "response_id": str(uuid.uuid4()),
- "timestamp": time.time(),
- }
- ]
- ),
- },
- }
- PromptServer.instance.send_sync(
- "display_component",
- render_spec,
- )
-
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
@@ -530,6 +546,13 @@ class GeminiImage(IO.ComfyNode):
"'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True,
),
+ IO.String.Input(
+ "system_prompt",
+ multiline=True,
+ default=GEMINI_IMAGE_SYS_PROMPT,
+ optional=True,
+ tooltip="Foundational instructions that dictate an AI's behavior.",
+ ),
],
outputs=[
IO.Image.Output(),
@@ -549,10 +572,11 @@ class GeminiImage(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
- images: torch.Tensor | None = None,
+ images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
+ system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@@ -562,11 +586,14 @@ class GeminiImage(IO.ComfyNode):
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
if images is not None:
- image_parts = create_image_parts(images)
- parts.extend(image_parts)
+ parts.extend(await create_image_parts(cls, images))
if files is not None:
parts.extend(files)
+ gemini_system_prompt = None
+ if system_prompt:
+ gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
+
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@@ -578,34 +605,12 @@ class GeminiImage(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=None if aspect_ratio == "auto" else image_config,
),
+ systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
-
- output_text = get_text_from_response(response)
- if output_text:
- render_spec = {
- "node_id": cls.hidden.unique_id,
- "component": "ChatHistoryWidget",
- "props": {
- "history": json.dumps(
- [
- {
- "prompt": prompt,
- "response": output_text,
- "response_id": str(uuid.uuid4()),
- "timestamp": time.time(),
- }
- ]
- ),
- },
- }
- PromptServer.instance.send_sync(
- "display_component",
- render_spec,
- )
- return IO.NodeOutput(get_image_from_response(response), output_text)
+ return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
class GeminiImage2(IO.ComfyNode):
@@ -671,6 +676,13 @@ class GeminiImage2(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
+ IO.String.Input(
+ "system_prompt",
+ multiline=True,
+ default=GEMINI_IMAGE_SYS_PROMPT,
+ optional=True,
+ tooltip="Foundational instructions that dictate an AI's behavior.",
+ ),
],
outputs=[
IO.Image.Output(),
@@ -693,8 +705,9 @@ class GeminiImage2(IO.ComfyNode):
aspect_ratio: str,
resolution: str,
response_modalities: str,
- images: torch.Tensor | None = None,
+ images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
+ system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
@@ -702,7 +715,7 @@ class GeminiImage2(IO.ComfyNode):
if images is not None:
if get_number_of_images(images) > 14:
raise ValueError("The current maximum number of supported images is 14.")
- parts.extend(create_image_parts(images))
+ parts.extend(await create_image_parts(cls, images))
if files is not None:
parts.extend(files)
@@ -710,6 +723,10 @@ class GeminiImage2(IO.ComfyNode):
if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio
+ gemini_system_prompt = None
+ if system_prompt:
+ gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
+
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@@ -721,34 +738,12 @@ class GeminiImage2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
),
+ systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
-
- output_text = get_text_from_response(response)
- if output_text:
- render_spec = {
- "node_id": cls.hidden.unique_id,
- "component": "ChatHistoryWidget",
- "props": {
- "history": json.dumps(
- [
- {
- "prompt": prompt,
- "response": output_text,
- "response_id": str(uuid.uuid4()),
- "timestamp": time.time(),
- }
- ]
- ),
- },
- }
- PromptServer.instance.send_sync(
- "display_component",
- render_spec,
- )
- return IO.NodeOutput(get_image_from_response(response), output_text)
+ return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension):
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index 36852038b..6c840dc47 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -4,15 +4,14 @@ For source of truth on the allowed permutations of request fields, please refere
- [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
"""
-from __future__ import annotations
-from typing import Optional, TypeVar
-import math
import logging
-
-from typing_extensions import override
+import math
+import re
import torch
+from typing_extensions import override
+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
KlingCameraControl,
KlingCameraConfig,
@@ -50,25 +49,33 @@ from comfy_api_nodes.apis import (
KlingCharacterEffectModelName,
KlingSingleImageEffectModelName,
)
+from comfy_api_nodes.apis.kling_api import (
+ OmniImageParamImage,
+ OmniParamImage,
+ OmniParamVideo,
+ OmniProFirstLastFrameRequest,
+ OmniProImageRequest,
+ OmniProReferences2VideoRequest,
+ OmniProText2VideoRequest,
+ OmniTaskStatusResponse,
+)
from comfy_api_nodes.util import (
- validate_image_dimensions,
+ ApiEndpoint,
+ download_url_to_image_tensor,
+ download_url_to_video_output,
+ get_number_of_images,
+ poll_op,
+ sync_op,
+ tensor_to_base64_string,
+ upload_audio_to_comfyapi,
+ upload_images_to_comfyapi,
+ upload_video_to_comfyapi,
validate_image_aspect_ratio,
+ validate_image_dimensions,
+ validate_string,
validate_video_dimensions,
validate_video_duration,
- tensor_to_base64_string,
- validate_string,
- upload_audio_to_comfyapi,
- download_url_to_image_tensor,
- upload_video_to_comfyapi,
- download_url_to_video_output,
- sync_op,
- ApiEndpoint,
- poll_op,
)
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.input.basic_types import AudioInput
-from comfy_api.input.video_types import VideoInput
-from comfy_api.latest import ComfyExtension, IO
KLING_API_VERSION = "v1"
PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
@@ -94,8 +101,6 @@ AVERAGE_DURATION_IMAGE_GEN = 32
AVERAGE_DURATION_VIDEO_EFFECTS = 320
AVERAGE_DURATION_VIDEO_EXTEND = 320
-R = TypeVar("R")
-
MODE_TEXT2VIDEO = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
@@ -130,6 +135,8 @@ MODE_START_END_FRAME = {
"pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"),
"pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"),
+ "pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"),
+ "pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"),
}
"""
Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples.
@@ -206,6 +213,50 @@ VOICES_CONFIG = {
}
+def normalize_omni_prompt_references(prompt: str) -> str:
+ """
+ Rewrites Kling Omni-style placeholders used in the app, like:
+
+ @image, @image1, @image2, ... @imageN
+ @video, @video1, @video2, ... @videoN
+
+ into the API-compatible form:
+
+        <<image1>>, <<image2>>, ... <<imageN>>
+        <<video1>>, <<video2>>, ... <<videoN>>
+
+ This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
+ """
+ if not prompt:
+ return prompt
+
+    def _image_repl(match):
+        return f"<<image{match.group('num') or '1'}>>"
+
+    def _video_repl(match):
+        return f"<<video{match.group('num') or '1'}>>"
+
+    # (?<!\w) skips e.g. foo@image1 and (?!\w) skips @imageFoo, so only standalone tags are rewritten.
+    prompt = re.sub(r"(?<!\w)@image(?P<num>\d*)(?!\w)", _image_repl, prompt)
+    return re.sub(r"(?<!\w)@video(?P<num>\d*)(?!\w)", _video_repl, prompt)
+
+
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
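+    """Raise if task creation returned an error code, then poll the omni-video task and return the downloaded video."""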
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
+ response_model=OmniTaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ max_poll_attempts=160,
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
def is_valid_camera_control_configs(configs: list[float]) -> bool:
"""Verifies that at least one camera control configuration is non-zero."""
return any(not math.isclose(value, 0.0) for value in configs)
@@ -296,7 +347,7 @@ def get_video_from_response(response) -> KlingVideoResult:
return video
-def get_video_url_from_response(response) -> Optional[str]:
+def get_video_url_from_response(response) -> str | None:
"""Returns the first video url from the Kling video generation task result.
Will not raise an error if the response is not valid.
"""
@@ -315,7 +366,7 @@ def get_images_from_response(response) -> list[KlingImageResult]:
return images
-def get_images_urls_from_response(response) -> Optional[str]:
+def get_images_urls_from_response(response) -> str | None:
"""Returns the list of image urls from the Kling image generation task result.
Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls.
"""
@@ -349,7 +400,7 @@ async def execute_text2video(
model_mode: str,
duration: str,
aspect_ratio: str,
- camera_control: Optional[KlingCameraControl] = None,
+ camera_control: KlingCameraControl | None = None,
) -> IO.NodeOutput:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
task_creation_response = await sync_op(
@@ -394,8 +445,8 @@ async def execute_image2video(
model_mode: str,
aspect_ratio: str,
duration: str,
- camera_control: Optional[KlingCameraControl] = None,
- end_frame: Optional[torch.Tensor] = None,
+ camera_control: KlingCameraControl | None = None,
+ end_frame: torch.Tensor | None = None,
) -> IO.NodeOutput:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
validate_input_image(start_frame)
@@ -451,9 +502,9 @@ async def execute_video_effect(
model_name: str,
duration: KlingVideoGenDuration,
image_1: torch.Tensor,
- image_2: Optional[torch.Tensor] = None,
- model_mode: Optional[KlingVideoGenMode] = None,
-) -> tuple[VideoFromFile, str, str]:
+ image_2: torch.Tensor | None = None,
+ model_mode: KlingVideoGenMode | None = None,
+) -> tuple[InputImpl.VideoFromFile, str, str]:
if dual_character:
request_input_field = KlingDualCharacterEffectInput(
model_name=model_name,
@@ -499,13 +550,13 @@ async def execute_video_effect(
async def execute_lipsync(
cls: type[IO.ComfyNode],
- video: VideoInput,
- audio: Optional[AudioInput] = None,
- voice_language: Optional[str] = None,
- model_mode: Optional[str] = None,
- text: Optional[str] = None,
- voice_speed: Optional[float] = None,
- voice_id: Optional[str] = None,
+ video: Input.Video,
+ audio: Input.Audio | None = None,
+ voice_language: str | None = None,
+ model_mode: str | None = None,
+ text: str | None = None,
+ voice_speed: float | None = None,
+ voice_id: str | None = None,
) -> IO.NodeOutput:
if text:
validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC)
@@ -740,6 +791,474 @@ class KlingTextToVideoNode(IO.ComfyNode):
)
+class OmniProTextToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProTextToVideoNode",
+ display_name="Kling Omni Text to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use text prompts to generate videos with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+ IO.Combo.Input("duration", options=[5, 10]),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ aspect_ratio: str,
+ duration: int,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2500)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProText2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=aspect_ratio,
+ duration=str(duration),
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProFirstLastFrameNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProFirstLastFrameNode",
+ display_name="Kling Omni First-Last-Frame to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("duration", options=["5", "10"]),
+ IO.Image.Input("first_frame"),
+ IO.Image.Input(
+ "end_frame",
+ optional=True,
+ tooltip="An optional end frame for the video. "
+ "This cannot be used simultaneously with 'reference_images'.",
+ ),
+ IO.Image.Input(
+ "reference_images",
+ optional=True,
+ tooltip="Up to 6 additional reference images.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ duration: int,
+ first_frame: Input.Image,
+ end_frame: Input.Image | None = None,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ if end_frame is not None and reference_images is not None:
+ raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
+ validate_image_dimensions(first_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
+ image_list: list[OmniParamImage] = [
+ OmniParamImage(
+ image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0],
+ type="first_frame",
+ )
+ ]
+ if end_frame is not None:
+ validate_image_dimensions(end_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
+ image_list.append(
+ OmniParamImage(
+ image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0],
+ type="end_frame",
+ )
+ )
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 6:
+ raise ValueError("The maximum number of reference images allowed is 6.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
+ image_list.append(OmniParamImage(image_url=i))
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProFirstLastFrameRequest(
+ model_name=model_name,
+ prompt=prompt,
+ duration=str(duration),
+ image_list=image_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProImageToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProImageToVideoNode",
+ display_name="Kling Omni Image to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use up to 7 reference images to generate a video with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+ IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 7 reference images.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ aspect_ratio: str,
+ duration: int,
+ reference_images: Input.Image,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ if get_number_of_images(reference_images) > 7:
+ raise ValueError("The maximum number of reference images is 7.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ image_list: list[OmniParamImage] = []
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniParamImage(image_url=i))
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProReferences2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=aspect_ratio,
+ duration=str(duration),
+ image_list=image_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProVideoToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProVideoToVideoNode",
+ display_name="Kling Omni Video to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+ IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Video.Input("reference_video", tooltip="Video to use as a reference."),
+ IO.Boolean.Input("keep_original_sound", default=True),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 4 additional reference images.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ aspect_ratio: str,
+ duration: int,
+ reference_video: Input.Video,
+ keep_original_sound: bool,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
+ validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
+ image_list: list[OmniParamImage] = []
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 4:
+ raise ValueError("The maximum number of reference images allowed with a video input is 4.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniParamImage(image_url=i))
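+        # refer_type="feature" marks the clip as a reference input; the edit node below uses refer_type="base".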
+ video_list = [
+ OmniParamVideo(
+ video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"),
+ refer_type="feature",
+ keep_original_sound="yes" if keep_original_sound else "no",
+ )
+ ]
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProReferences2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=aspect_ratio,
+ duration=str(duration),
+ image_list=image_list if image_list else None,
+ video_list=video_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProEditVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProEditVideoNode",
+ display_name="Kling Omni Edit Video (Pro)",
+ category="api node/video/Kling",
+ description="Edit an existing video with the latest model from Kling.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."),
+ IO.Boolean.Input("keep_original_sound", default=True),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 4 additional reference images.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ video: Input.Video,
+ keep_original_sound: bool,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ validate_video_duration(video, min_duration=3.0, max_duration=10.05)
+ validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
+ image_list: list[OmniParamImage] = []
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 4:
+ raise ValueError("The maximum number of reference images allowed with a video input is 4.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniParamImage(image_url=i))
+ video_list = [
+ OmniParamVideo(
+ video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"),
+ refer_type="base",
+ keep_original_sound="yes" if keep_original_sound else "no",
+ )
+ ]
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProReferences2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
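+                # aspect_ratio and duration are not sent; the edited output keeps the base video's length.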
+ aspect_ratio=None,
+ duration=None,
+ image_list=image_list if image_list else None,
+ video_list=video_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProImageNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProImageNode",
+ display_name="Kling Omni Image (Pro)",
+ category="api node/image/Kling",
+ description="Create or edit images with the latest model from Kling.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-image-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the image content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("resolution", options=["1K", "2K"]),
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
+ ),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 10 additional reference images.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Image.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ resolution: str,
+ aspect_ratio: str,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ image_list: list[OmniImageParamImage] = []
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 10:
+ raise ValueError("The maximum number of reference images is 10.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniImageParamImage(image=i))
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProImageRequest(
+ model_name=model_name,
+ prompt=prompt,
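+                # The widget exposes "1K"/"2K"; the API expects lowercase "1k"/"2k".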
+ resolution=resolution.lower(),
+ aspect_ratio=aspect_ratio,
+ image_list=image_list if image_list else None,
+ ),
+ )
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
+ response_model=OmniTaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
+
+
class KlingCameraControlT2VNode(IO.ComfyNode):
"""
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@@ -787,7 +1306,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
negative_prompt: str,
cfg_scale: float,
aspect_ratio: str,
- camera_control: Optional[KlingCameraControl] = None,
+ camera_control: KlingCameraControl | None = None,
) -> IO.NodeOutput:
return await execute_text2video(
cls,
@@ -854,8 +1373,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
mode: str,
aspect_ratio: str,
duration: str,
- camera_control: Optional[KlingCameraControl] = None,
- end_frame: Optional[torch.Tensor] = None,
+ camera_control: KlingCameraControl | None = None,
+ end_frame: torch.Tensor | None = None,
) -> IO.NodeOutput:
return await execute_image2video(
cls,
@@ -965,15 +1484,11 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"),
IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0),
- IO.Combo.Input(
- "aspect_ratio",
- options=[i.value for i in KlingVideoGenAspectRatio],
- default="16:9",
- ),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input(
"mode",
options=modes,
- default=modes[2],
+ default=modes[8],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@@ -1170,7 +1685,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
category="api node/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene.",
inputs=[
- IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"),
+ IO.Image.Input(
+ "image",
+ tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1",
+ ),
IO.Combo.Input(
"effect_scene",
options=[i.value for i in KlingSingleImageEffectsScene],
@@ -1254,8 +1772,8 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- video: VideoInput,
- audio: AudioInput,
+ video: Input.Video,
+ audio: Input.Audio,
voice_language: str,
) -> IO.NodeOutput:
return await execute_lipsync(
@@ -1314,7 +1832,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- video: VideoInput,
+ video: Input.Video,
text: str,
voice: str,
voice_speed: float,
@@ -1471,7 +1989,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
human_fidelity: float,
n: int,
aspect_ratio: KlingImageGenAspectRatio,
- image: Optional[torch.Tensor] = None,
+ image: torch.Tensor | None = None,
) -> IO.NodeOutput:
validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
@@ -1533,6 +2051,12 @@ class KlingExtension(ComfyExtension):
KlingImageGenerationNode,
KlingSingleImageVideoEffectNode,
KlingDualCharacterVideoEffectNode,
+ OmniProTextToVideoNode,
+ OmniProFirstLastFrameNode,
+ OmniProImageToVideoNode,
+ OmniProVideoToVideoNode,
+ OmniProEditVideoNode,
+ # OmniProImageNode, # need support from backend
]
diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py
index 0b757a62b..7e61560dc 100644
--- a/comfy_api_nodes/nodes_ltxv.py
+++ b/comfy_api_nodes/nodes_ltxv.py
@@ -1,12 +1,9 @@
from io import BytesIO
-from typing import Optional
-import torch
from pydantic import BaseModel, Field
from typing_extensions import override
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
@@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
- fps: Optional[int] = Field(25)
- generate_audio: Optional[bool] = Field(True)
- image_uri: Optional[str] = Field(None)
+ fps: int | None = Field(25)
+ generate_audio: bool | None = Field(True)
+ image_uri: str | None = Field(None)
class TextToVideoNode(IO.ComfyNode):
@@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
- return IO.NodeOutput(VideoFromFile(BytesIO(response)))
+ return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- image: torch.Tensor,
+ image: Input.Image,
model: str,
prompt: str,
duration: int,
@@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
- return IO.NodeOutput(VideoFromFile(BytesIO(response)))
+ return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):
diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py
index 7c31d95b3..2771e4790 100644
--- a/comfy_api_nodes/nodes_moonvalley.py
+++ b/comfy_api_nodes/nodes_moonvalley.py
@@ -1,11 +1,8 @@
import logging
-from typing import Optional
-import torch
from typing_extensions import override
-from comfy_api.input import VideoInput
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
@@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
raise RuntimeError(error_msg)
-def validate_video_to_video_input(video: VideoInput) -> VideoInput:
+def validate_video_to_video_input(video: Input.Video) -> Input.Video:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
return _validate_and_trim_duration(video)
-def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
+def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
@@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
-def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
+def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
@@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
raise ValueError("Input video must be at least 5 seconds long.")
-def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
+def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
@@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- image: torch.Tensor,
+ image: Input.Image,
prompt: str,
negative_prompt: str,
resolution: str,
@@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
prompt: str,
negative_prompt: str,
seed: int,
- video: Optional[VideoInput] = None,
+ video: Input.Video | None = None,
control_type: str = "Motion Transfer",
- motion_intensity: Optional[int] = 100,
+ motion_intensity: int | None = 100,
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:
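
Editor's note: the Moonvalley helpers now annotate against `Input.Video` instead of the old `VideoInput` class. A hedged, self-contained sketch of a duration check written against the new alias (the helper name and 5-second threshold mirror the validation shown above):

from comfy_api.latest import Input

def require_min_duration(video: Input.Video, min_seconds: float = 5.0) -> None:
    # Mirrors the minimum-duration check in nodes_moonvalley.py, now typed with Input.Video.
    if video.get_duration() < min_seconds:
        raise ValueError(f"Input video must be at least {min_seconds} seconds long.")
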
diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py
index acf35d276..c8da5464b 100644
--- a/comfy_api_nodes/nodes_openai.py
+++ b/comfy_api_nodes/nodes_openai.py
@@ -1,15 +1,10 @@
from io import BytesIO
-from typing import Optional, Union
-import json
import os
-import time
-import uuid
from enum import Enum
from inspect import cleandoc
import numpy as np
import torch
from PIL import Image
-from server import PromptServer
import folder_paths
import base64
from comfy_api.latest import IO, ComfyExtension
@@ -587,11 +582,11 @@ class OpenAIChatNode(IO.ComfyNode):
def create_input_message_contents(
cls,
prompt: str,
- image: Optional[torch.Tensor] = None,
- files: Optional[list[InputFileContent]] = None,
+ image: torch.Tensor | None = None,
+ files: list[InputFileContent] | None = None,
) -> InputMessageContentList:
"""Create a list of input message contents from prompt and optional image."""
- content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [
+ content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [
InputTextContent(text=prompt, type="input_text"),
]
if image is not None:
@@ -617,9 +612,9 @@ class OpenAIChatNode(IO.ComfyNode):
prompt: str,
persist_context: bool = False,
model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value,
- images: Optional[torch.Tensor] = None,
- files: Optional[list[InputFileContent]] = None,
- advanced_options: Optional[CreateModelResponseProperties] = None,
+ images: torch.Tensor | None = None,
+ files: list[InputFileContent] | None = None,
+ advanced_options: CreateModelResponseProperties | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@@ -660,30 +655,7 @@ class OpenAIChatNode(IO.ComfyNode):
status_extractor=lambda response: response.status,
completed_statuses=["incomplete", "completed"]
)
- output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))
-
- # Update history
- render_spec = {
- "node_id": cls.hidden.unique_id,
- "component": "ChatHistoryWidget",
- "props": {
- "history": json.dumps(
- [
- {
- "prompt": prompt,
- "response": output_text,
- "response_id": str(uuid.uuid4()),
- "timestamp": time.time(),
- }
- ]
- ),
- },
- }
- PromptServer.instance.send_sync(
- "display_component",
- render_spec,
- )
- return IO.NodeOutput(output_text)
+ return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)))
class OpenAIInputFiles(IO.ComfyNode):
@@ -790,8 +762,8 @@ class OpenAIChatConfig(IO.ComfyNode):
def execute(
cls,
truncation: bool,
- instructions: Optional[str] = None,
- max_output_tokens: Optional[int] = None,
+ instructions: str | None = None,
+ max_output_tokens: int | None = None,
) -> IO.NodeOutput:
"""
Configure advanced options for the OpenAI Chat Node.
diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py
index 51148211b..acd88c391 100644
--- a/comfy_api_nodes/nodes_pika.py
+++ b/comfy_api_nodes/nodes_pika.py
@@ -92,6 +92,7 @@ class PikaImageToVideo(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -152,6 +153,7 @@ class PikaTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -239,6 +241,7 @@ class PikaScenes(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -323,6 +326,7 @@ class PikAdditionsNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -399,6 +403,7 @@ class PikaSwapsNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -466,6 +471,7 @@ class PikaffectsNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
@@ -515,6 +521,7 @@ class PikaStartEndFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ is_deprecated=True,
)
@classmethod
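
Editor's note: every Pika node schema gains `is_deprecated=True`. A minimal, hypothetical schema showing where the flag sits (the node id, inputs, and outputs below are illustrative only):

from comfy_api.latest import IO

class ExampleDeprecatedNode(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ExampleDeprecatedNode",  # hypothetical id
            category="api node/example",
            inputs=[IO.String.Input("prompt", multiline=True)],
            outputs=[IO.Video.Output()],
            is_api_node=True,
            is_deprecated=True,  # marks the node as deprecated in the UI, as done for the Pika nodes here
        )

    @classmethod
    async def execute(cls, prompt: str) -> IO.NodeOutput:
        raise NotImplementedError("illustrative only")
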
diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py
index 2fdafbbfe..3c55039c9 100644
--- a/comfy_api_nodes/nodes_runway.py
+++ b/comfy_api_nodes/nodes_runway.py
@@ -11,12 +11,11 @@ User Guides:
"""
-from typing import Union, Optional
-from typing_extensions import override
from enum import Enum
-import torch
+from typing_extensions import override
+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
@@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
sync_op,
poll_op,
)
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import ComfyExtension, IO
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
field_1280_768 = "1280:768"
-def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
+def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
def extract_progress_from_task_status(
response: TaskStatusResponse,
-) -> Union[float, None]:
+) -> float | None:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None
-def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
+def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the image URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
async def get_response(
- cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
+ cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_op(
@@ -119,8 +116,8 @@ async def get_response(
async def generate_video(
cls: type[IO.ComfyNode],
request: RunwayImageToVideoRequest,
- estimated_duration: Optional[int] = None,
-) -> VideoFromFile:
+ estimated_duration: int | None = None,
+) -> InputImpl.VideoFromFile:
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
@@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
async def execute(
cls,
prompt: str,
- start_frame: torch.Tensor,
+ start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
async def execute(
cls,
prompt: str,
- start_frame: torch.Tensor,
+ start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
async def execute(
cls,
prompt: str,
- start_frame: torch.Tensor,
- end_frame: torch.Tensor,
+ start_frame: Input.Image,
+ end_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
cls,
prompt: str,
ratio: str,
- reference_image: Optional[torch.Tensor] = None,
+ reference_image: Input.Image | None = None,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)
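
Editor's note: the Runway helpers keep the same task flow, just with builtin-generic annotations: pull the first output URL from the finished task (if any) and download it into a VIDEO output. A hedged sketch of that tail end, with `download_url_to_video_output` imported from its defining module:

from comfy_api.latest import IO
from comfy_api_nodes.util.download_helpers import download_url_to_video_output

def first_output_url(response) -> str | None:
    # Same guard as get_video_url_from_task_status(): return the first output if present.
    if hasattr(response, "output") and len(response.output) > 0:
        return response.output[0]
    return None

async def to_video_output(cls: type[IO.ComfyNode], response) -> IO.NodeOutput:
    url = first_output_url(response)
    if url is None:
        raise RuntimeError("Task finished without a video URL")  # illustrative error handling
    return IO.NodeOutput(await download_url_to_video_output(url, cls=cls))
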
diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py
index d37e9e9b4..e165b8380 100644
--- a/comfy_api_nodes/nodes_veo2.py
+++ b/comfy_api_nodes/nodes_veo2.py
@@ -3,13 +3,15 @@ from io import BytesIO
from typing_extensions import override
-from comfy_api.input_impl.video_types import VideoFromFile
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
VeoGenVidRequest,
VeoGenVidResponse,
+ VeoRequestInstance,
+ VeoRequestInstanceImage,
+ VeoRequestParameters,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@@ -228,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Check if video is provided as base64 or URL
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
- return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
+ return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if hasattr(video, "gcsUri") and video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
@@ -346,12 +348,163 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
)
+class Veo3FirstLastFrameNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Veo3FirstLastFrameNode",
+ display_name="Google Veo 3 First-Last-Frame to Video",
+ category="api node/video/Veo",
+ description="Generate video using prompt and first and last frames.",
+ inputs=[
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ default="",
+ tooltip="Text description of the video",
+ ),
+ IO.String.Input(
+ "negative_prompt",
+ multiline=True,
+ default="",
+ tooltip="Negative text prompt to guide what to avoid in the video",
+ ),
+ IO.Combo.Input("resolution", options=["720p", "1080p"]),
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=["16:9", "9:16"],
+ default="16:9",
+ tooltip="Aspect ratio of the output video",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=8,
+ min=4,
+ max=8,
+ step=2,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=0xFFFFFFFF,
+ step=1,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed for video generation",
+ ),
+ IO.Image.Input("first_frame", tooltip="Start frame"),
+ IO.Image.Input("last_frame", tooltip="End frame"),
+ IO.Combo.Input(
+ "model",
+ options=["veo-3.1-generate", "veo-3.1-fast-generate"],
+ default="veo-3.1-fast-generate",
+ ),
+ IO.Boolean.Input(
+ "generate_audio",
+ default=True,
+ tooltip="Generate audio for the video.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ negative_prompt: str,
+ resolution: str,
+ aspect_ratio: str,
+ duration: int,
+ seed: int,
+ first_frame: Input.Image,
+ last_frame: Input.Image,
+ model: str,
+ generate_audio: bool,
+ ):
+ model = MODELS_MAP[model]
+ initial_response = await sync_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
+ response_model=VeoGenVidResponse,
+ data=VeoGenVidRequest(
+ instances=[
+ VeoRequestInstance(
+ prompt=prompt,
+ image=VeoRequestInstanceImage(
+ bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png"
+ ),
+ lastFrame=VeoRequestInstanceImage(
+ bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png"
+ ),
+ ),
+ ],
+ parameters=VeoRequestParameters(
+ aspectRatio=aspect_ratio,
+ personGeneration="ALLOW",
+ durationSeconds=duration,
+ enhancePrompt=True, # cannot be False for Veo3
+ seed=seed,
+ generateAudio=generate_audio,
+ negativePrompt=negative_prompt,
+ resolution=resolution,
+ ),
+ ),
+ )
+ poll_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
+ response_model=VeoGenVidPollResponse,
+ status_extractor=lambda r: "completed" if r.done else "pending",
+ data=VeoGenVidPollRequest(
+ operationName=initial_response.name,
+ ),
+ poll_interval=5.0,
+ estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
+ )
+
+ if poll_response.error:
+ raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
+
+ response = poll_response.response
+ filtered_count = response.raiMediaFilteredCount
+ if filtered_count:
+ reasons = response.raiMediaFilteredReasons or []
+ reason_part = f": {reasons[0]}" if reasons else ""
+ raise Exception(
+ f"Content blocked by Google's Responsible AI filters{reason_part} "
+ f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
+ )
+
+ if response.videos:
+ video = response.videos[0]
+ if video.bytesBase64Encoded:
+ return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
+ if video.gcsUri:
+ return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
+ raise Exception("Video returned but no data or URL was provided")
+ raise Exception("Video generation completed but no video was returned")
+
+
class VeoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
VeoVideoGenerationNode,
Veo3VideoGenerationNode,
+ Veo3FirstLastFrameNode,
]
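
Editor's note: the new Veo3FirstLastFrameNode follows the same submit-then-poll shape as the other API nodes: `sync_op` posts the generate request, then `poll_op` re-queries until `done`. A condensed sketch of that control flow only (request construction and result handling omitted; `request` is assumed to be a prepared `VeoGenVidRequest`):

from comfy_api.latest import IO
from comfy_api_nodes.apis.veo_api import VeoGenVidPollRequest, VeoGenVidPollResponse, VeoGenVidResponse
from comfy_api_nodes.util import ApiEndpoint, poll_op, sync_op

async def submit_and_wait(cls: type[IO.ComfyNode], request, model: str) -> VeoGenVidPollResponse:
    # 1) Submit the generation job.
    initial = await sync_op(
        cls,
        ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
        response_model=VeoGenVidResponse,
        data=request,
    )
    # 2) Poll the operation until it reports done.
    return await poll_op(
        cls,
        ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
        response_model=VeoGenVidPollResponse,
        status_extractor=lambda r: "completed" if r.done else "pending",
        data=VeoGenVidPollRequest(operationName=initial.name),
        poll_interval=5.0,
    )
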
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
index 80292fb3c..4cc22abfb 100644
--- a/comfy_api_nodes/util/__init__.py
+++ b/comfy_api_nodes/util/__init__.py
@@ -47,6 +47,7 @@ from .validation_utils import (
validate_string,
validate_video_dimensions,
validate_video_duration,
+ validate_video_frame_count,
)
__all__ = [
@@ -94,6 +95,7 @@ __all__ = [
"validate_string",
"validate_video_dimensions",
"validate_video_duration",
+ "validate_video_frame_count",
# Misc functions
"get_fs_object_size",
]
diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py
index 328fe5227..491e6b6a8 100644
--- a/comfy_api_nodes/util/_helpers.py
+++ b/comfy_api_nodes/util/_helpers.py
@@ -2,8 +2,8 @@ import asyncio
import contextlib
import os
import time
+from collections.abc import Callable
from io import BytesIO
-from typing import Callable, Optional, Union
from comfy.cli_args import args
from comfy.model_management import processing_interrupted
@@ -35,12 +35,12 @@ def default_base_url() -> str:
async def sleep_with_interrupt(
seconds: float,
- node_cls: Optional[type[IO.ComfyNode]],
- label: Optional[str] = None,
- start_ts: Optional[float] = None,
- estimated_total: Optional[int] = None,
+ node_cls: type[IO.ComfyNode] | None,
+ label: str | None = None,
+ start_ts: float | None = None,
+ estimated_total: int | None = None,
*,
- display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
+ display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
):
"""
Sleep in 1s slices while:
@@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower()
-def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
+def get_fs_object_size(path_or_object: str | BytesIO) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())
diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py
index bf01d7d36..bf37cba5f 100644
--- a/comfy_api_nodes/util/client.py
+++ b/comfy_api_nodes/util/client.py
@@ -4,10 +4,11 @@ import json
import logging
import time
import uuid
+from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
-from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import Any, Literal, TypeVar
from urllib.parse import urljoin, urlparse
import aiohttp
@@ -37,8 +38,8 @@ class ApiEndpoint:
path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
*,
- query_params: Optional[dict[str, Any]] = None,
- headers: Optional[dict[str, str]] = None,
+ query_params: dict[str, Any] | None = None,
+ headers: dict[str, str] | None = None,
):
self.path = path
self.method = method
@@ -52,18 +53,18 @@ class _RequestConfig:
endpoint: ApiEndpoint
timeout: float
content_type: str
- data: Optional[dict[str, Any]]
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
- multipart_parser: Optional[Callable]
+ data: dict[str, Any] | None
+ files: dict[str, Any] | list[tuple[str, Any]] | None
+ multipart_parser: Callable | None
max_retries: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
monitor_progress: bool = True
- estimated_total: Optional[int] = None
- final_label_on_success: Optional[str] = "Completed"
- progress_origin_ts: Optional[float] = None
- price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
+ estimated_total: int | None = None
+ final_label_on_success: str | None = "Completed"
+ progress_origin_ts: float | None = None
+ price_extractor: Callable[[dict[str, Any]], float | None] | None = None
@dataclass
@@ -71,10 +72,10 @@ class _PollUIState:
started: float
status_label: str = "Queued"
is_queued: bool = True
- price: Optional[float] = None
- estimated_duration: Optional[int] = None
+ price: float | None = None
+ estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
- active_since: Optional[float] = None # start time of current active interval (None if queued)
+ active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@@ -87,20 +88,20 @@ async def sync_op(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
- response_model: Type[M],
- price_extractor: Optional[Callable[[M], Optional[float]]] = None,
- data: Optional[BaseModel] = None,
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+ response_model: type[M],
+ price_extractor: Callable[[M | Any], float | None] | None = None,
+ data: BaseModel | None = None,
+ files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json",
timeout: float = 3600.0,
- multipart_parser: Optional[Callable] = None,
+ multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
- estimated_duration: Optional[int] = None,
- final_label_on_success: Optional[str] = "Completed",
- progress_origin_ts: Optional[float] = None,
+ estimated_duration: int | None = None,
+ final_label_on_success: str | None = "Completed",
+ progress_origin_ts: float | None = None,
monitor_progress: bool = True,
) -> M:
raw = await sync_op_raw(
@@ -131,22 +132,22 @@ async def poll_op(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
- response_model: Type[M],
- status_extractor: Callable[[M], Optional[Union[str, int]]],
- progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
- price_extractor: Optional[Callable[[M], Optional[float]]] = None,
- completed_statuses: Optional[list[Union[str, int]]] = None,
- failed_statuses: Optional[list[Union[str, int]]] = None,
- queued_statuses: Optional[list[Union[str, int]]] = None,
- data: Optional[BaseModel] = None,
+ response_model: type[M],
+ status_extractor: Callable[[M | Any], str | int | None],
+ progress_extractor: Callable[[M | Any], int | None] | None = None,
+ price_extractor: Callable[[M | Any], float | None] | None = None,
+ completed_statuses: list[str | int] | None = None,
+ failed_statuses: list[str | int] | None = None,
+ queued_statuses: list[str | int] | None = None,
+ data: BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
- estimated_duration: Optional[int] = None,
- cancel_endpoint: Optional[ApiEndpoint] = None,
+ estimated_duration: int | None = None,
+ cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
) -> M:
raw = await poll_op_raw(
@@ -178,22 +179,22 @@ async def sync_op_raw(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
- price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
- data: Optional[Union[dict[str, Any], BaseModel]] = None,
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+ price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
+ data: dict[str, Any] | BaseModel | None = None,
+ files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json",
timeout: float = 3600.0,
- multipart_parser: Optional[Callable] = None,
+ multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
- estimated_duration: Optional[int] = None,
+ estimated_duration: int | None = None,
as_binary: bool = False,
- final_label_on_success: Optional[str] = "Completed",
- progress_origin_ts: Optional[float] = None,
+ final_label_on_success: str | None = "Completed",
+ progress_origin_ts: float | None = None,
monitor_progress: bool = True,
-) -> Union[dict[str, Any], bytes]:
+) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON).
@@ -229,21 +230,21 @@ async def poll_op_raw(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
- status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
- progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
- price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
- completed_statuses: Optional[list[Union[str, int]]] = None,
- failed_statuses: Optional[list[Union[str, int]]] = None,
- queued_statuses: Optional[list[Union[str, int]]] = None,
- data: Optional[Union[dict[str, Any], BaseModel]] = None,
+ status_extractor: Callable[[dict[str, Any]], str | int | None],
+ progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
+ price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
+ completed_statuses: list[str | int] | None = None,
+ failed_statuses: list[str | int] | None = None,
+ queued_statuses: list[str | int] | None = None,
+ data: dict[str, Any] | BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
- estimated_duration: Optional[int] = None,
- cancel_endpoint: Optional[ApiEndpoint] = None,
+ estimated_duration: int | None = None,
+ cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
) -> dict[str, Any]:
"""
@@ -261,7 +262,7 @@ async def poll_op_raw(
consumed_attempts = 0 # counts only non-queued polls
progress_bar = utils.ProgressBar(100) if progress_extractor else None
- last_progress: Optional[int] = None
+ last_progress: int | None = None
state = _PollUIState(started=started, estimated_duration=estimated_duration)
stop_ticker = asyncio.Event()
@@ -420,10 +421,10 @@ async def poll_op_raw(
def _display_text(
node_cls: type[IO.ComfyNode],
- text: Optional[str],
+ text: str | None,
*,
- status: Optional[Union[str, int]] = None,
- price: Optional[float] = None,
+ status: str | int | None = None,
+ price: float | None = None,
) -> None:
display_lines: list[str] = []
if status:
@@ -440,13 +441,13 @@ def _display_text(
def _display_time_progress(
node_cls: type[IO.ComfyNode],
- status: Optional[Union[str, int]],
+ status: str | int | None,
elapsed_seconds: int,
- estimated_total: Optional[int] = None,
+ estimated_total: int | None = None,
*,
- price: Optional[float] = None,
- is_queued: Optional[bool] = None,
- processing_elapsed_seconds: Optional[int] = None,
+ price: float | None = None,
+ is_queued: bool | None = None,
+ processing_elapsed_seconds: int | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
-def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
+def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
def _snapshot_request_body_for_logging(
content_type: str,
method: str,
- data: Optional[dict[str, Any]],
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
-) -> Optional[Union[dict[str, Any], str]]:
+ data: dict[str, Any] | None,
+ files: dict[str, Any] | list[tuple[str, Any]] | None,
+) -> dict[str, Any] | str | None:
if method.upper() == "GET":
return None
if content_type == "multipart/form-data":
@@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
attempt = 0
delay = cfg.retry_delay
operation_succeeded: bool = False
- final_elapsed_seconds: Optional[int] = None
- extracted_price: Optional[float] = None
+ final_elapsed_seconds: int | None = None
+ extracted_price: float | None = None
while True:
attempt += 1
stop_event = asyncio.Event()
- monitor_task: Optional[asyncio.Task] = None
- sess: Optional[aiohttp.ClientSession] = None
+ monitor_task: asyncio.Task | None = None
+ sess: aiohttp.ClientSession | None = None
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
@@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
)
-def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
+def _validate_or_raise(response_model: type[M], payload: Any) -> M:
try:
return response_model.model_validate(payload)
except Exception as e:
@@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
def _wrap_model_extractor(
- response_model: Type[M],
- extractor: Optional[Callable[[M], Any]],
-) -> Optional[Callable[[dict[str, Any]], Any]]:
+ response_model: type[M],
+ extractor: Callable[[M], Any] | None,
+) -> Callable[[dict[str, Any]], Any] | None:
"""Wrap a typed extractor so it can be used by the dict-based poller.
Validates the dict into `response_model` before invoking `extractor`.
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@@ -929,10 +930,10 @@ def _wrap_model_extractor(
return _wrapped
-def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
+def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
if not values:
return set()
- out: set[Union[str, int]] = set()
+ out: set[str | int] = set()
for v in values:
nv = _normalize_status_value(v)
if nv is not None:
@@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio
return out
-def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
+def _normalize_status_value(val: str | int | None) -> str | int | None:
if isinstance(val, str):
return val.strip().lower()
return val
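
Editor's note: the client module drops `typing.Callable`/`Optional`/`Union` in favour of `collections.abc.Callable` plus PEP 604 unions. A tiny, self-contained sketch of what the new extractor annotations look like (the "price" field is illustrative):

from collections.abc import Callable
from typing import Any

# New-style alias matching the dict-based extractors accepted by sync_op_raw/poll_op_raw.
PriceExtractor = Callable[[dict[str, Any]], float | None]

def extract_price(payload: dict[str, Any]) -> float | None:
    # Hypothetical extractor: read an optional "price" field from a raw JSON payload.
    value = payload.get("price")
    return float(value) if value is not None else None

extractor: PriceExtractor = extract_price
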
diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py
index 971dc57de..c57457580 100644
--- a/comfy_api_nodes/util/conversions.py
+++ b/comfy_api_nodes/util/conversions.py
@@ -4,7 +4,6 @@ import math
import mimetypes
import uuid
from io import BytesIO
-from typing import Optional
import av
import numpy as np
@@ -12,8 +11,7 @@ import torch
from PIL import Image
from comfy.utils import common_upscale
-from comfy_api.latest import Input, InputImpl
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import Input, InputImpl, Types
from ._helpers import mimetype_to_extension
@@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
def tensor_to_bytesio(
image: torch.Tensor,
- name: Optional[str] = None,
+ name: str | None = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
@@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co
def video_to_base64_string(
video: Input.Video,
- container_format: VideoContainer = None,
- codec: VideoCodec = None
+ container_format: Types.VideoContainer | None = None,
+ codec: Types.VideoCodec | None = None,
) -> str:
"""
Converts a video input to a base64 string.
@@ -189,12 +187,11 @@ def video_to_base64_string(
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = BytesIO()
-
- # Use provided format/codec if specified, otherwise use video's own if available
- format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
- codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
-
- video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
+ video.save_to(
+ video_bytes_io,
+ format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
+ codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
+ )
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
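
Editor's note: `video_to_base64_string` now falls back to `Types.VideoContainer.MP4` / `Types.VideoCodec.H264` via the `comfy_api.latest.Types` namespace when no container or codec is passed. A hedged usage sketch, assuming `video` is any `Input.Video`:

from comfy_api.latest import Input, Types
from comfy_api_nodes.util.conversions import video_to_base64_string

def encode_video(video: Input.Video) -> str:
    # Explicit container/codec shown here; omitting both falls back to the video's own
    # settings when available, else MP4/H.264.
    return video_to_base64_string(video, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264)
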
diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py
index 14207dc68..3e0d0352d 100644
--- a/comfy_api_nodes/util/download_helpers.py
+++ b/comfy_api_nodes/util/download_helpers.py
@@ -3,15 +3,15 @@ import contextlib
import uuid
from io import BytesIO
from pathlib import Path
-from typing import IO, Optional, Union
+from typing import IO
from urllib.parse import urljoin, urlparse
import aiohttp
import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError
-from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO
+from comfy_api.latest import InputImpl
from . import request_logger
from ._helpers import (
@@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
async def download_url_to_bytesio(
url: str,
- dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
+ dest: BytesIO | IO[bytes] | str | Path | None,
*,
- timeout: Optional[float] = None,
+ timeout: float | None = None,
max_retries: int = 5,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
@@ -71,10 +71,10 @@ async def download_url_to_bytesio(
is_path_sink = isinstance(dest, (str, Path))
fhandle = None
- session: Optional[aiohttp.ClientSession] = None
- stop_evt: Optional[asyncio.Event] = None
- monitor_task: Optional[asyncio.Task] = None
- req_task: Optional[asyncio.Task] = None
+ session: aiohttp.ClientSession | None = None
+ stop_evt: asyncio.Event | None = None
+ monitor_task: asyncio.Task | None = None
+ req_task: asyncio.Task | None = None
try:
with contextlib.suppress(Exception):
@@ -234,11 +234,11 @@ async def download_url_to_video_output(
timeout: float = None,
max_retries: int = 5,
cls: type[COMFY_IO.ComfyNode] = None,
-) -> VideoFromFile:
+) -> InputImpl.VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
- return VideoFromFile(result)
+ return InputImpl.VideoFromFile(result)
async def download_url_as_bytesio(
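
Editor's note: `download_url_to_bytesio` still accepts a BytesIO, a binary file object, or a filesystem path as the destination sink; only the annotations changed. A hedged sketch of downloading straight to disk (the URL, directory, and filename are placeholders):

from pathlib import Path
from comfy_api.latest import IO
from comfy_api_nodes.util.download_helpers import download_url_to_bytesio

async def save_clip(cls: type[IO.ComfyNode], url: str, out_dir: Path) -> Path:
    # Stream the response directly into a file on disk instead of holding it in memory.
    dest = out_dir / "clip.mp4"  # placeholder filename
    await download_url_to_bytesio(url, dest, max_retries=5, cls=cls)  # cls is the calling node class, as in the other helpers
    return dest
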
diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py
index ac52e2eab..e0cb4428d 100644
--- a/comfy_api_nodes/util/request_logger.py
+++ b/comfy_api_nodes/util/request_logger.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import datetime
import hashlib
import json
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index 632450d9b..b8d33f4d1 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -4,15 +4,13 @@ import logging
import time
import uuid
from io import BytesIO
-from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
import torch
from pydantic import BaseModel, Field
-from comfy_api.latest import IO, Input
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import IO, Input, Types
from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
@@ -32,7 +30,7 @@ from .conversions import (
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
- content_type: Optional[str] = Field(
+ content_type: str | None = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
@@ -48,22 +46,30 @@ async def upload_images_to_comfyapi(
image: torch.Tensor,
*,
max_images: int = 8,
- mime_type: Optional[str] = None,
- wait_label: Optional[str] = "Uploading",
+ mime_type: str | None = None,
+ wait_label: str | None = "Uploading",
+ show_batch_index: bool = True,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
"""
- # if batch, try to upload each file if max_images is greater than 0
+ # if batched, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
+ num_to_upload = min(batch_len, max_images)
+ batch_start_ts = time.monotonic()
- for idx in range(min(batch_len, max_images)):
+ for idx in range(num_to_upload):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
- url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
+
+ effective_label = wait_label
+ if wait_label and show_batch_index and num_to_upload > 1:
+ effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
+
+ url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
download_urls.append(url)
return download_urls
@@ -92,9 +98,10 @@ async def upload_video_to_comfyapi(
cls: type[IO.ComfyNode],
video: Input.Video,
*,
- container: VideoContainer = VideoContainer.MP4,
- codec: VideoCodec = VideoCodec.H264,
- max_duration: Optional[int] = None,
+ container: Types.VideoContainer = Types.VideoContainer.MP4,
+ codec: Types.VideoCodec = Types.VideoCodec.H264,
+ max_duration: int | None = None,
+ wait_label: str | None = "Uploading",
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
@@ -119,15 +126,16 @@ async def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
- return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
+ return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode],
file_bytes_io: BytesIO,
filename: str,
- upload_mime_type: Optional[str],
- wait_label: Optional[str] = "Uploading",
+ upload_mime_type: str | None,
+ wait_label: str | None = "Uploading",
+ progress_origin_ts: float | None = None,
) -> str:
"""Uploads a single file to ComfyUI API and returns its download URL."""
if upload_mime_type is None:
@@ -148,6 +156,7 @@ async def upload_file_to_comfyapi(
file_bytes_io,
content_type=upload_mime_type,
wait_label=wait_label,
+ progress_origin_ts=progress_origin_ts,
)
return create_resp.download_url
@@ -155,27 +164,18 @@ async def upload_file_to_comfyapi(
async def upload_file(
cls: type[IO.ComfyNode],
upload_url: str,
- file: Union[BytesIO, str],
+ file: BytesIO | str,
*,
- content_type: Optional[str] = None,
+ content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
- wait_label: Optional[str] = None,
+ wait_label: str | None = None,
+ progress_origin_ts: float | None = None,
) -> None:
"""
Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
- Args:
- cls: Node class (provides auth context + UI progress hooks).
- upload_url: Pre-signed PUT URL.
- file: BytesIO or path string.
- content_type: Explicit MIME type. If None, we *suppress* Content-Type.
- max_retries: Maximum retry attempts.
- retry_delay: Initial delay in seconds.
- retry_backoff: Exponential backoff factor.
- wait_label: Progress label shown in Comfy UI.
-
Raises:
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
"""
@@ -198,7 +198,7 @@ async def upload_file(
attempt = 0
delay = retry_delay
- start_ts = time.monotonic()
+ start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
op_uuid = uuid.uuid4().hex[:8]
while True:
attempt += 1
@@ -218,7 +218,7 @@ async def upload_file(
return
monitor_task = asyncio.create_task(_monitor())
- sess: Optional[aiohttp.ClientSession] = None
+ sess: aiohttp.ClientSession | None = None
try:
try:
request_logger.log_request_response(
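
Editor's note: `upload_images_to_comfyapi` now derives a per-image progress label ("Uploading (2/4)") and threads a shared batch start timestamp into `upload_file_to_comfyapi`, so the elapsed-time display covers the whole batch. A hedged call-site sketch; the node class and tensor are placeholders, and the leading `cls` argument is assumed to match the other upload helpers:

import torch
from comfy_api.latest import IO
from comfy_api_nodes.util.upload_helpers import upload_images_to_comfyapi

async def upload_refs(cls: type[IO.ComfyNode], images: torch.Tensor) -> list[str]:
    # images: [B, H, W, C] batch; one download URL is returned per uploaded image.
    return await upload_images_to_comfyapi(
        cls,
        images,
        max_images=4,
        wait_label="Uploading reference images",
        show_batch_index=True,  # appends "(i/n)" to the label when more than one image is uploaded
    )
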
diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py
index ec7006aed..f01edea96 100644
--- a/comfy_api_nodes/util/validation_utils.py
+++ b/comfy_api_nodes/util/validation_utils.py
@@ -1,9 +1,7 @@
import logging
-from typing import Optional
import torch
-from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input
@@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
def validate_image_dimensions(
image: torch.Tensor,
- min_width: Optional[int] = None,
- max_width: Optional[int] = None,
- min_height: Optional[int] = None,
- max_height: Optional[int] = None,
+ min_width: int | None = None,
+ max_width: int | None = None,
+ min_height: int | None = None,
+ max_height: int | None = None,
):
height, width = get_image_dimensions(image)
@@ -37,8 +35,8 @@ def validate_image_dimensions(
def validate_image_aspect_ratio(
image: torch.Tensor,
- min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
- max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
+ min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
+ max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
*,
strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float:
@@ -54,8 +52,8 @@ def validate_image_aspect_ratio(
def validate_images_aspect_ratio_closeness(
first_image: torch.Tensor,
second_image: torch.Tensor,
- min_rel: float, # e.g. 0.8
- max_rel: float, # e.g. 1.25
+ min_rel: float, # e.g. 0.8
+ max_rel: float, # e.g. 1.25
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
@@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(
def validate_aspect_ratio_string(
aspect_ratio: str,
- min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
- max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
+ min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
+ max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
@@ -97,10 +95,10 @@ def validate_aspect_ratio_string(
def validate_video_dimensions(
video: Input.Video,
- min_width: Optional[int] = None,
- max_width: Optional[int] = None,
- min_height: Optional[int] = None,
- max_height: Optional[int] = None,
+ min_width: int | None = None,
+ max_width: int | None = None,
+ min_height: int | None = None,
+ max_height: int | None = None,
):
try:
width, height = video.get_dimensions()
@@ -120,8 +118,8 @@ def validate_video_dimensions(
def validate_video_duration(
video: Input.Video,
- min_duration: Optional[float] = None,
- max_duration: Optional[float] = None,
+ min_duration: float | None = None,
+ max_duration: float | None = None,
):
try:
duration = video.get_duration()
@@ -136,6 +134,23 @@ def validate_video_duration(
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
+def validate_video_frame_count(
+ video: Input.Video,
+ min_frame_count: int | None = None,
+ max_frame_count: int | None = None,
+):
+ try:
+ frame_count = video.get_frame_count()
+ except Exception as e:
+ logging.error("Error getting frame count of video: %s", e)
+ return
+
+ if min_frame_count is not None and min_frame_count > frame_count:
+ raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
+ if max_frame_count is not None and frame_count > max_frame_count:
+ raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
+
+
def get_number_of_images(images):
if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1
@@ -144,8 +159,8 @@ def get_number_of_images(images):
def validate_audio_duration(
audio: Input.Audio,
- min_duration: Optional[float] = None,
- max_duration: Optional[float] = None,
+ min_duration: float | None = None,
+ max_duration: float | None = None,
) -> None:
sr = int(audio["sample_rate"])
dur = int(audio["waveform"].shape[-1]) / sr
@@ -177,7 +192,7 @@ def validate_string(
)
-def validate_container_format_is_mp4(video: VideoInput) -> None:
+def validate_container_format_is_mp4(video: Input.Video) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
@@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
def _assert_ratio_bounds(
ar: float,
*,
- min_ratio: Optional[tuple[float, float]] = None,
- max_ratio: Optional[tuple[float, float]] = None,
+ min_ratio: tuple[float, float] | None = None,
+ max_ratio: tuple[float, float] | None = None,
strict: bool = True,
) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py
index cec105fc9..24c0b4ed7 100644
--- a/comfy_execution/validation.py
+++ b/comfy_execution/validation.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+from comfy_api.latest import IO
def validate_node_input(
@@ -23,6 +24,11 @@ def validate_node_input(
if not received_type != input_type:
return True
+ # If the received type or input_type is a MatchType, we can return True immediately;
+ # validation for this is handled by the frontend
+ if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
+ return True
+
# Not equal, and not strings
if not isinstance(received_type, str) or not isinstance(input_type, str):
return False
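
Editor's note: `validate_node_input` now returns True as soon as either side is the `MatchType` sentinel, deferring that check to the frontend. A hedged illustration of the resulting behaviour:

from comfy_api.latest import IO
from comfy_execution.validation import validate_node_input

# Identical concrete types still compare as before...
assert validate_node_input("IMAGE", "IMAGE")
# ...but a MatchType on either side is accepted unconditionally; the frontend resolves it.
assert validate_node_input(IO.MatchType.io_type, "IMAGE")
assert validate_node_input("LATENT", IO.MatchType.io_type)
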
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 2ed7e0b22..c7916443c 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -6,65 +6,80 @@ import torch
import comfy.model_management
import folder_paths
import os
-import io
-import json
-import random
import hashlib
import node_helpers
import logging
-from comfy.cli_args import args
-from comfy.comfy_types import FileLocator
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, IO, UI
-class EmptyLatentAudio:
- def __init__(self):
- self.device = comfy.model_management.intermediate_device()
+class EmptyLatentAudio(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="EmptyLatentAudio",
+ display_name="Empty Latent Audio",
+ category="latent/audio",
+ inputs=[
+ IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
+ IO.Int.Input(
+ "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
+ ),
+ ],
+ outputs=[IO.Latent.Output()],
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
- "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
- }}
- RETURN_TYPES = ("LATENT",)
- FUNCTION = "generate"
-
- CATEGORY = "latent/audio"
-
- def generate(self, seconds, batch_size):
+ def execute(cls, seconds, batch_size) -> IO.NodeOutput:
length = round((seconds * 44100 / 2048) / 2) * 2
- latent = torch.zeros([batch_size, 64, length], device=self.device)
- return ({"samples":latent, "type": "audio"}, )
+ latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
+ return IO.NodeOutput({"samples":latent, "type": "audio"})
-class ConditioningStableAudio:
+ generate = execute # TODO: remove
+
+
+class ConditioningStableAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {"positive": ("CONDITIONING", ),
- "negative": ("CONDITIONING", ),
- "seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
- "seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ConditioningStableAudio",
+ category="conditioning",
+ inputs=[
+ IO.Conditioning.Input("positive"),
+ IO.Conditioning.Input("negative"),
+ IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
+ IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
+ ],
+ outputs=[
+ IO.Conditioning.Output(display_name="positive"),
+ IO.Conditioning.Output(display_name="negative"),
+ ],
+ )
- RETURN_TYPES = ("CONDITIONING","CONDITIONING")
- RETURN_NAMES = ("positive", "negative")
-
- FUNCTION = "append"
-
- CATEGORY = "conditioning"
-
- def append(self, positive, negative, seconds_start, seconds_total):
+ @classmethod
+ def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
- return (positive, negative)
+ return IO.NodeOutput(positive, negative)
-class VAEEncodeAudio:
+ append = execute # TODO: remove
+
+
+class VAEEncodeAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
- RETURN_TYPES = ("LATENT",)
- FUNCTION = "encode"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="VAEEncodeAudio",
+ display_name="VAE Encode Audio",
+ category="latent/audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Vae.Input("vae"),
+ ],
+ outputs=[IO.Latent.Output()],
+ )
- CATEGORY = "latent/audio"
-
- def encode(self, vae, audio):
+ @classmethod
+ def execute(cls, vae, audio) -> IO.NodeOutput:
sample_rate = audio["sample_rate"]
if 44100 != sample_rate:
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
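
Editor's note: the audio nodes are ported from the dict-based V1 interface (`INPUT_TYPES`/`RETURN_TYPES`/`FUNCTION`) to the schema-based API: a classmethod `define_schema`, a classmethod `execute` returning `IO.NodeOutput`, and a temporary alias to the old method name for backward compatibility. The skeleton of that conversion, using a hypothetical node as the example:

import torch
from comfy_api.latest import IO

class ExampleLatentNode(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ExampleLatentNode",  # hypothetical node, for illustration only
            display_name="Example Latent",
            category="latent/audio",
            inputs=[IO.Int.Input("batch_size", default=1, min=1, max=4096)],
            outputs=[IO.Latent.Output()],
        )

    @classmethod
    def execute(cls, batch_size) -> IO.NodeOutput:
        latent = torch.zeros([batch_size, 64, 1024])
        return IO.NodeOutput({"samples": latent, "type": "audio"})

    generate = execute  # keeps the old V1 entry-point name callable; scheduled for removal
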
@@ -72,213 +87,134 @@ class VAEEncodeAudio:
waveform = audio["waveform"]
t = vae.encode(waveform.movedim(1, -1))
- return ({"samples":t}, )
+ return IO.NodeOutput({"samples":t})
-class VAEDecodeAudio:
+ encode = execute # TODO: remove
+
+
+class VAEDecodeAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "decode"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="VAEDecodeAudio",
+ display_name="VAE Decode Audio",
+ category="latent/audio",
+ inputs=[
+ IO.Latent.Input("samples"),
+ IO.Vae.Input("vae"),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- CATEGORY = "latent/audio"
-
- def decode(self, vae, samples):
+ @classmethod
+ def execute(cls, vae, samples) -> IO.NodeOutput:
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
- return ({"waveform": audio, "sample_rate": 44100}, )
+ return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
+
+ decode = execute # TODO: remove
-def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
-
- filename_prefix += self.prefix_append
- full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
- results: list[FileLocator] = []
-
- # Prepare metadata dictionary
- metadata = {}
- if not args.disable_metadata:
- if prompt is not None:
- metadata["prompt"] = json.dumps(prompt)
- if extra_pnginfo is not None:
- for x in extra_pnginfo:
- metadata[x] = json.dumps(extra_pnginfo[x])
-
- # Opus supported sample rates
- OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
-
- for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
- filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
- file = f"{filename_with_batch_num}_{counter:05}_.{format}"
- output_path = os.path.join(full_output_folder, file)
-
- # Use original sample rate initially
- sample_rate = audio["sample_rate"]
-
- # Handle Opus sample rate requirements
- if format == "opus":
- if sample_rate > 48000:
- sample_rate = 48000
- elif sample_rate not in OPUS_RATES:
- # Find the next highest supported rate
- for rate in sorted(OPUS_RATES):
- if rate > sample_rate:
- sample_rate = rate
- break
- if sample_rate not in OPUS_RATES: # Fallback if still not supported
- sample_rate = 48000
-
- # Resample if necessary
- if sample_rate != audio["sample_rate"]:
- waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
-
- # Create output with specified format
- output_buffer = io.BytesIO()
- output_container = av.open(output_buffer, mode='w', format=format)
-
- # Set metadata on the container
- for key, value in metadata.items():
- output_container.metadata[key] = value
-
- layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
- # Set up the output stream with appropriate properties
- if format == "opus":
- out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
- if quality == "64k":
- out_stream.bit_rate = 64000
- elif quality == "96k":
- out_stream.bit_rate = 96000
- elif quality == "128k":
- out_stream.bit_rate = 128000
- elif quality == "192k":
- out_stream.bit_rate = 192000
- elif quality == "320k":
- out_stream.bit_rate = 320000
- elif format == "mp3":
- out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
- if quality == "V0":
- #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
- out_stream.codec_context.qscale = 1
- elif quality == "128k":
- out_stream.bit_rate = 128000
- elif quality == "320k":
- out_stream.bit_rate = 320000
- else: #format == "flac":
- out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
-
- frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
- frame.sample_rate = sample_rate
- frame.pts = 0
- output_container.mux(out_stream.encode(frame))
-
- # Flush encoder
- output_container.mux(out_stream.encode(None))
-
- # Close containers
- output_container.close()
-
- # Write the output to file
- output_buffer.seek(0)
- with open(output_path, 'wb') as f:
- f.write(output_buffer.getbuffer())
-
- results.append({
- "filename": file,
- "subfolder": subfolder,
- "type": self.type
- })
- counter += 1
-
- return { "ui": { "audio": results } }
-
-class SaveAudio:
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
- self.type = "output"
- self.prefix_append = ""
+class SaveAudio(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SaveAudio",
+ display_name="Save Audio (FLAC)",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ),
- "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
- },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
+ return IO.NodeOutput(
+ ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
+ )
- RETURN_TYPES = ()
- FUNCTION = "save_flac"
+ save_flac = execute # TODO: remove
- OUTPUT_NODE = True
- CATEGORY = "audio"
-
- def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
- return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
-
-class SaveAudioMP3:
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
- self.type = "output"
- self.prefix_append = ""
+class SaveAudioMP3(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SaveAudioMP3",
+ display_name="Save Audio (MP3)",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+ IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ),
- "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
- "quality": (["V0", "128k", "320k"], {"default": "V0"}),
- },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
+ return IO.NodeOutput(
+ ui=UI.AudioSaveHelper.get_save_audio_ui(
+ audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
+ )
+ )
- RETURN_TYPES = ()
- FUNCTION = "save_mp3"
+ save_mp3 = execute # TODO: remove
- OUTPUT_NODE = True
- CATEGORY = "audio"
-
- def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
- return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
-
-class SaveAudioOpus:
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
- self.type = "output"
- self.prefix_append = ""
+class SaveAudioOpus(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SaveAudioOpus",
+ display_name="Save Audio (Opus)",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+ IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ),
- "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
- "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
- },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
+ return IO.NodeOutput(
+ ui=UI.AudioSaveHelper.get_save_audio_ui(
+ audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
+ )
+ )
- RETURN_TYPES = ()
- FUNCTION = "save_opus"
+ save_opus = execute # TODO: remove
- OUTPUT_NODE = True
- CATEGORY = "audio"
-
- def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
- return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
-
-class PreviewAudio(SaveAudio):
- def __init__(self):
- self.output_dir = folder_paths.get_temp_directory()
- self.type = "temp"
- self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
+class PreviewAudio(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="PreviewAudio",
+ display_name="Preview Audio",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"audio": ("AUDIO", ), },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio) -> IO.NodeOutput:
+ return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
+
+ save_flac = execute # TODO: remove
+
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format."""
@@ -316,26 +252,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]:
wav = f32_pcm(wav)
return wav, sr
-class LoadAudio:
+class LoadAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
+ def define_schema(cls):
input_dir = folder_paths.get_input_directory()
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
- return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
+ return IO.Schema(
+ node_id="LoadAudio",
+ display_name="Load Audio",
+ category="audio",
+ inputs=[
+ IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- CATEGORY = "audio"
-
- RETURN_TYPES = ("AUDIO", )
- FUNCTION = "load"
-
- def load(self, audio):
+ @classmethod
+ def execute(cls, audio) -> IO.NodeOutput:
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
- return (audio, )
+ return IO.NodeOutput(audio)
@classmethod
- def IS_CHANGED(s, audio):
+ def fingerprint_inputs(cls, audio):
image_path = folder_paths.get_annotated_filepath(audio)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
@@ -343,46 +283,69 @@ class LoadAudio:
return m.digest().hex()
@classmethod
- def VALIDATE_INPUTS(s, audio):
+ def validate_inputs(cls, audio):
if not folder_paths.exists_annotated_filepath(audio):
return "Invalid audio file: {}".format(audio)
return True
-class RecordAudio:
+ load = execute # TODO: remove
+
+
+class RecordAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {"audio": ("AUDIO_RECORD", {})}}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="RecordAudio",
+ display_name="Record Audio",
+ category="audio",
+ inputs=[
+ IO.Custom("AUDIO_RECORD").Input("audio"),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- CATEGORY = "audio"
-
- RETURN_TYPES = ("AUDIO", )
- FUNCTION = "load"
-
- def load(self, audio):
+ @classmethod
+ def execute(cls, audio) -> IO.NodeOutput:
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
- return (audio, )
+ return IO.NodeOutput(audio)
+
+ load = execute # TODO: remove
-class TrimAudioDuration:
+class TrimAudioDuration(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "audio": ("AUDIO",),
- "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
- "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
- },
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="TrimAudioDuration",
+ display_name="Trim Audio Duration",
+ description="Trim audio tensor into chosen time range.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Float.Input(
+ "start_index",
+ default=0.0,
+ min=-0xffffffffffffffff,
+ max=0xffffffffffffffff,
+ step=0.01,
+ tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
+ ),
+ IO.Float.Input(
+ "duration",
+ default=60.0,
+ min=0.0,
+ step=0.01,
+ tooltip="Duration in seconds",
+ ),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- FUNCTION = "trim"
- RETURN_TYPES = ("AUDIO",)
- CATEGORY = "audio"
- DESCRIPTION = "Trim audio tensor into chosen time range."
-
- def trim(self, audio, start_index, duration):
+ @classmethod
+ def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
@@ -399,23 +362,30 @@ class TrimAudioDuration:
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
- return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
+ return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
+
+ trim = execute # TODO: remove
-class SplitAudioChannels:
+class SplitAudioChannels(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "audio": ("AUDIO",),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SplitAudioChannels",
+ display_name="Split Audio Channels",
+ description="Separates the audio into left and right channels.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ ],
+ outputs=[
+ IO.Audio.Output(display_name="left"),
+ IO.Audio.Output(display_name="right"),
+ ],
+ )
- RETURN_TYPES = ("AUDIO", "AUDIO")
- RETURN_NAMES = ("left", "right")
- FUNCTION = "separate"
- CATEGORY = "audio"
- DESCRIPTION = "Separates the audio into left and right channels."
-
- def separate(self, audio):
+ @classmethod
+ def execute(cls, audio) -> IO.NodeOutput:
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
@@ -425,7 +395,9 @@ class SplitAudioChannels:
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
- return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
+ return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
+
+ separate = execute # TODO: remove
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
@@ -443,21 +415,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
return waveform_1, waveform_2, output_sample_rate
-class AudioConcat:
+class AudioConcat(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "audio1": ("AUDIO",),
- "audio2": ("AUDIO",),
- "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioConcat",
+ display_name="Audio Concat",
+ description="Concatenates the audio1 to audio2 in the specified direction.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio1"),
+ IO.Audio.Input("audio2"),
+ IO.Combo.Input(
+ "direction",
+ options=['after', 'before'],
+ default="after",
+ tooltip="Whether to append audio2 after or before audio1.",
+ )
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "concat"
- CATEGORY = "audio"
- DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
-
- def concat(self, audio1, audio2, direction):
+ @classmethod
+ def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@@ -477,26 +457,33 @@ class AudioConcat:
elif direction == 'before':
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
- return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
+ return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
+
+ concat = execute # TODO: remove
-class AudioMerge:
+class AudioMerge(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "audio1": ("AUDIO",),
- "audio2": ("AUDIO",),
- "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
- },
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioMerge",
+ display_name="Audio Merge",
+ description="Combine two audio tracks by overlaying their waveforms.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio1"),
+ IO.Audio.Input("audio2"),
+ IO.Combo.Input(
+ "merge_method",
+ options=["add", "mean", "subtract", "multiply"],
+ tooltip="The method used to combine the audio waveforms.",
+ )
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- FUNCTION = "merge"
- RETURN_TYPES = ("AUDIO",)
- CATEGORY = "audio"
- DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
-
- def merge(self, audio1, audio2, merge_method):
+ @classmethod
+ def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@@ -530,85 +517,110 @@ class AudioMerge:
if max_val > 1.0:
waveform = waveform / max_val
- return ({"waveform": waveform, "sample_rate": output_sample_rate},)
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
+
+ merge = execute # TODO: remove
-class AudioAdjustVolume:
+class AudioAdjustVolume(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "audio": ("AUDIO",),
- "volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioAdjustVolume",
+ display_name="Audio Adjust Volume",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Int.Input(
+ "volume",
+ default=1,
+ min=-100,
+ max=100,
+ tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
+ )
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "adjust_volume"
- CATEGORY = "audio"
-
- def adjust_volume(self, audio, volume):
+ @classmethod
+ def execute(cls, audio, volume) -> IO.NodeOutput:
if volume == 0:
- return (audio,)
+ return IO.NodeOutput(audio)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
gain = 10 ** (volume / 20)
waveform = waveform * gain
- return ({"waveform": waveform, "sample_rate": sample_rate},)
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
+
+ adjust_volume = execute # TODO: remove
-class EmptyAudio:
+class EmptyAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
- "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
- "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="EmptyAudio",
+ display_name="Empty Audio",
+ category="audio",
+ inputs=[
+ IO.Float.Input(
+ "duration",
+ default=60.0,
+ min=0.0,
+ max=0xffffffffffffffff,
+ step=0.01,
+ tooltip="Duration of the empty audio clip in seconds",
+ ),
+ IO.Int.Input(
+ "sample_rate",
+ default=44100,
+ tooltip="Sample rate of the empty audio clip.",
+ min=1,
+ max=192000,
+ ),
+ IO.Int.Input(
+ "channels",
+ default=2,
+ min=1,
+ max=2,
+ tooltip="Number of audio channels (1 for mono, 2 for stereo).",
+ ),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "create_empty_audio"
- CATEGORY = "audio"
-
- def create_empty_audio(self, duration, sample_rate, channels):
+ @classmethod
+ def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
num_samples = int(round(duration * sample_rate))
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
- return ({"waveform": waveform, "sample_rate": sample_rate},)
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
+
+ create_empty_audio = execute # TODO: remove
-NODE_CLASS_MAPPINGS = {
- "EmptyLatentAudio": EmptyLatentAudio,
- "VAEEncodeAudio": VAEEncodeAudio,
- "VAEDecodeAudio": VAEDecodeAudio,
- "SaveAudio": SaveAudio,
- "SaveAudioMP3": SaveAudioMP3,
- "SaveAudioOpus": SaveAudioOpus,
- "LoadAudio": LoadAudio,
- "PreviewAudio": PreviewAudio,
- "ConditioningStableAudio": ConditioningStableAudio,
- "RecordAudio": RecordAudio,
- "TrimAudioDuration": TrimAudioDuration,
- "SplitAudioChannels": SplitAudioChannels,
- "AudioConcat": AudioConcat,
- "AudioMerge": AudioMerge,
- "AudioAdjustVolume": AudioAdjustVolume,
- "EmptyAudio": EmptyAudio,
-}
+class AudioExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ EmptyLatentAudio,
+ VAEEncodeAudio,
+ VAEDecodeAudio,
+ SaveAudio,
+ SaveAudioMP3,
+ SaveAudioOpus,
+ LoadAudio,
+ PreviewAudio,
+ ConditioningStableAudio,
+ RecordAudio,
+ TrimAudioDuration,
+ SplitAudioChannels,
+ AudioConcat,
+ AudioMerge,
+ AudioAdjustVolume,
+ EmptyAudio,
+ ]
-NODE_DISPLAY_NAME_MAPPINGS = {
- "EmptyLatentAudio": "Empty Latent Audio",
- "VAEEncodeAudio": "VAE Encode Audio",
- "VAEDecodeAudio": "VAE Decode Audio",
- "PreviewAudio": "Preview Audio",
- "LoadAudio": "Load Audio",
- "SaveAudio": "Save Audio (FLAC)",
- "SaveAudioMP3": "Save Audio (MP3)",
- "SaveAudioOpus": "Save Audio (Opus)",
- "RecordAudio": "Record Audio",
- "TrimAudioDuration": "Trim Audio Duration",
- "SplitAudioChannels": "Split Audio Channels",
- "AudioConcat": "Audio Concat",
- "AudioMerge": "Audio Merge",
- "AudioAdjustVolume": "Audio Adjust Volume",
- "EmptyAudio": "Empty Audio",
-}
+async def comfy_entrypoint() -> AudioExtension:
+ return AudioExtension()
diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py
index 1c3d9e697..3799a9004 100644
--- a/comfy_extras/nodes_context_windows.py
+++ b/comfy_extras/nodes_context_windows.py
@@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
+ io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
+ #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
+ #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
],
outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."),
@@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
)
@classmethod
- def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
+    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
+                cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.NodeOutput:
model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode):
context_overlap=context_overlap,
context_stride=context_stride,
closed_loop=closed_loop,
- dim=dim)
+ dim=dim,
+ freenoise=freenoise,
+ cond_retain_index_list=cond_retain_index_list,
+ split_conds_to_windows=split_conds_to_windows
+ )
# make memory usage calculation only take into account the context window latents
comfy.context_windows.create_prepare_sampling_wrapper(model)
+ if freenoise: # no other use for this wrapper at this time
+ comfy.context_windows.create_sampler_sample_wrapper(model)
return io.NodeOutput(model)
class WanContextWindowsManualNode(ContextWindowsManualNode):
@@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
+ io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
+ #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
+ #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
]
return schema
@classmethod
- def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
+    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
+                cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.NodeOutput:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
- return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
+ return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
class ContextWindowsExtension(ComfyExtension):
diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py
index d011f433b..fbb080886 100644
--- a/comfy_extras/nodes_custom_sampler.py
+++ b/comfy_extras/nodes_custom_sampler.py
@@ -3,272 +3,312 @@ import comfy.samplers
import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.k_diffusion import sa_solver
-from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
import latent_preview
import torch
import comfy.utils
import node_helpers
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
-class BasicScheduler:
+class BasicScheduler(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"model": ("MODEL",),
- "scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
- "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/schedulers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="BasicScheduler",
+ category="sampling/custom_sampling/schedulers",
+ inputs=[
+ io.Model.Input("model"),
+ io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES),
+ io.Int.Input("steps", default=20, min=1, max=10000),
+ io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
+ ],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, model, scheduler, steps, denoise):
+ @classmethod
+ def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput:
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
- return (torch.FloatTensor([]),)
+ return io.NodeOutput(torch.FloatTensor([]))
total_steps = int(steps/denoise)
sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
sigmas = sigmas[-(steps + 1):]
- return (sigmas, )
+ return io.NodeOutput(sigmas)
+
+ get_sigmas = execute
-class KarrasScheduler:
+class KarrasScheduler(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
- "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
- "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/schedulers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="KarrasScheduler",
+ category="sampling/custom_sampling/schedulers",
+ inputs=[
+ io.Int.Input("steps", default=20, min=1, max=10000),
+ io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False),
+ io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False),
+ io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False),
+ ],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, steps, sigma_max, sigma_min, rho):
+ @classmethod
+ def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput:
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
- return (sigmas, )
+ return io.NodeOutput(sigmas)
-class ExponentialScheduler:
+ get_sigmas = execute
+
+class ExponentialScheduler(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
- "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/schedulers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="ExponentialScheduler",
+ category="sampling/custom_sampling/schedulers",
+ inputs=[
+ io.Int.Input("steps", default=20, min=1, max=10000),
+ io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False),
+ io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False),
+ ],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, steps, sigma_max, sigma_min):
+ @classmethod
+ def execute(cls, steps, sigma_max, sigma_min) -> io.NodeOutput:
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max)
- return (sigmas, )
+ return io.NodeOutput(sigmas)
-class PolyexponentialScheduler:
+ get_sigmas = execute
+
+class PolyexponentialScheduler(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
- "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
- "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/schedulers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="PolyexponentialScheduler",
+ category="sampling/custom_sampling/schedulers",
+ inputs=[
+ io.Int.Input("steps", default=20, min=1, max=10000),
+ io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False),
+ io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False),
+ io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ ],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, steps, sigma_max, sigma_min, rho):
+ @classmethod
+ def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
- return (sigmas, )
+ return io.NodeOutput(sigmas)
-class LaplaceScheduler:
+ get_sigmas = execute
+
+class LaplaceScheduler(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
- "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
- "mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
- "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/schedulers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="LaplaceScheduler",
+ category="sampling/custom_sampling/schedulers",
+ inputs=[
+ io.Int.Input("steps", default=20, min=1, max=10000),
+ io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False),
+ io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False),
+ io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False),
+ io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False),
+ ],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
+ @classmethod
+ def execute(cls, steps, sigma_max, sigma_min, mu, beta) -> io.NodeOutput:
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
- return (sigmas, )
+ return io.NodeOutput(sigmas)
+
+ get_sigmas = execute
-class SDTurboScheduler:
+class SDTurboScheduler(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"model": ("MODEL",),
- "steps": ("INT", {"default": 1, "min": 1, "max": 10}),
- "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/schedulers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SDTurboScheduler",
+ category="sampling/custom_sampling/schedulers",
+ inputs=[
+ io.Model.Input("model"),
+ io.Int.Input("steps", default=1, min=1, max=10),
+ io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01),
+ ],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, model, steps, denoise):
+ @classmethod
+ def execute(cls, model, steps, denoise) -> io.NodeOutput:
start_step = 10 - int(10 * denoise)
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
sigmas = model.get_model_object("model_sampling").sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
- return (sigmas, )
+ return io.NodeOutput(sigmas)
-class BetaSamplingScheduler:
+ get_sigmas = execute
+
+class BetaSamplingScheduler(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"model": ("MODEL",),
- "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
- "beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/schedulers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="BetaSamplingScheduler",
+ category="sampling/custom_sampling/schedulers",
+ inputs=[
+ io.Model.Input("model"),
+ io.Int.Input("steps", default=20, min=1, max=10000),
+ io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
+ io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
+ ],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, model, steps, alpha, beta):
+ @classmethod
+ def execute(cls, model, steps, alpha, beta) -> io.NodeOutput:
sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
- return (sigmas, )
+ return io.NodeOutput(sigmas)
-class VPScheduler:
+ get_sigmas = execute
+
+class VPScheduler(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
- "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values
- "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
- "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/schedulers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="VPScheduler",
+ category="sampling/custom_sampling/schedulers",
+ inputs=[
+ io.Int.Input("steps", default=20, min=1, max=10000),
+ io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False), #TODO: fix default values
+ io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False),
+ io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False),
+ ],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, steps, beta_d, beta_min, eps_s):
+ @classmethod
+ def execute(cls, steps, beta_d, beta_min, eps_s) -> io.NodeOutput:
sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s)
- return (sigmas, )
+ return io.NodeOutput(sigmas)
-class SplitSigmas:
+ get_sigmas = execute
+
+class SplitSigmas(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"sigmas": ("SIGMAS", ),
- "step": ("INT", {"default": 0, "min": 0, "max": 10000}),
- }
- }
- RETURN_TYPES = ("SIGMAS","SIGMAS")
- RETURN_NAMES = ("high_sigmas", "low_sigmas")
- CATEGORY = "sampling/custom_sampling/sigmas"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SplitSigmas",
+ category="sampling/custom_sampling/sigmas",
+ inputs=[
+ io.Sigmas.Input("sigmas"),
+ io.Int.Input("step", default=0, min=0, max=10000),
+ ],
+ outputs=[
+ io.Sigmas.Output(display_name="high_sigmas"),
+ io.Sigmas.Output(display_name="low_sigmas"),
+ ]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, sigmas, step):
+ @classmethod
+ def execute(cls, sigmas, step) -> io.NodeOutput:
sigmas1 = sigmas[:step + 1]
sigmas2 = sigmas[step:]
- return (sigmas1, sigmas2)
+ return io.NodeOutput(sigmas1, sigmas2)
-class SplitSigmasDenoise:
+ get_sigmas = execute
+
+class SplitSigmasDenoise(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"sigmas": ("SIGMAS", ),
- "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
- }
- }
- RETURN_TYPES = ("SIGMAS","SIGMAS")
- RETURN_NAMES = ("high_sigmas", "low_sigmas")
- CATEGORY = "sampling/custom_sampling/sigmas"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SplitSigmasDenoise",
+ category="sampling/custom_sampling/sigmas",
+ inputs=[
+ io.Sigmas.Input("sigmas"),
+ io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
+ ],
+ outputs=[
+ io.Sigmas.Output(display_name="high_sigmas"),
+ io.Sigmas.Output(display_name="low_sigmas"),
+ ]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, sigmas, denoise):
+ @classmethod
+ def execute(cls, sigmas, denoise) -> io.NodeOutput:
steps = max(sigmas.shape[-1] - 1, 0)
total_steps = round(steps * denoise)
sigmas1 = sigmas[:-(total_steps)]
sigmas2 = sigmas[-(total_steps + 1):]
- return (sigmas1, sigmas2)
+ return io.NodeOutput(sigmas1, sigmas2)
-class FlipSigmas:
+ get_sigmas = execute
+
+class FlipSigmas(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"sigmas": ("SIGMAS", ),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/sigmas"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="FlipSigmas",
+ category="sampling/custom_sampling/sigmas",
+ inputs=[io.Sigmas.Input("sigmas")],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "get_sigmas"
-
- def get_sigmas(self, sigmas):
+ @classmethod
+ def execute(cls, sigmas) -> io.NodeOutput:
if len(sigmas) == 0:
- return (sigmas,)
+ return io.NodeOutput(sigmas)
sigmas = sigmas.flip(0)
if sigmas[0] == 0:
sigmas[0] = 0.0001
- return (sigmas,)
+ return io.NodeOutput(sigmas)
-class SetFirstSigma:
+ get_sigmas = execute
+
+class SetFirstSigma(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"sigmas": ("SIGMAS", ),
- "sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/sigmas"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SetFirstSigma",
+ category="sampling/custom_sampling/sigmas",
+ inputs=[
+ io.Sigmas.Input("sigmas"),
+ io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False),
+ ],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "set_first_sigma"
-
- def set_first_sigma(self, sigmas, sigma):
+ @classmethod
+ def execute(cls, sigmas, sigma) -> io.NodeOutput:
sigmas = sigmas.clone()
sigmas[0] = sigma
- return (sigmas, )
+ return io.NodeOutput(sigmas)
-class ExtendIntermediateSigmas:
+ set_first_sigma = execute
+
+class ExtendIntermediateSigmas(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"sigmas": ("SIGMAS", ),
- "steps": ("INT", {"default": 2, "min": 1, "max": 100}),
- "start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}),
- "end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}),
- "spacing": (['linear', 'cosine', 'sine'],),
- }
- }
- RETURN_TYPES = ("SIGMAS",)
- CATEGORY = "sampling/custom_sampling/sigmas"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="ExtendIntermediateSigmas",
+ category="sampling/custom_sampling/sigmas",
+ inputs=[
+ io.Sigmas.Input("sigmas"),
+ io.Int.Input("steps", default=2, min=1, max=100),
+ io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False),
+ io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False),
+ io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']),
+ ],
+ outputs=[io.Sigmas.Output()]
+ )
- FUNCTION = "extend"
-
- def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str):
+ @classmethod
+ def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str) -> io.NodeOutput:
if start_at_sigma < 0:
start_at_sigma = float("inf")
@@ -299,27 +339,27 @@ class ExtendIntermediateSigmas:
extended_sigmas = torch.FloatTensor(extended_sigmas)
- return (extended_sigmas,)
+ return io.NodeOutput(extended_sigmas)
+
+ extend = execute
-class SamplingPercentToSigma:
+class SamplingPercentToSigma(io.ComfyNode):
@classmethod
- def INPUT_TYPES(cls) -> InputTypeDict:
- return {
- "required": {
- "model": (IO.MODEL, {}),
- "sampling_percent": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001}),
- "return_actual_sigma": (IO.BOOLEAN, {"default": False, "tooltip": "Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."}),
- }
- }
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplingPercentToSigma",
+ category="sampling/custom_sampling/sigmas",
+ inputs=[
+ io.Model.Input("model"),
+ io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001),
+ io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."),
+ ],
+ outputs=[io.Float.Output(display_name="sigma_value")]
+ )
- RETURN_TYPES = (IO.FLOAT,)
- RETURN_NAMES = ("sigma_value",)
- CATEGORY = "sampling/custom_sampling/sigmas"
-
- FUNCTION = "get_sigma"
-
- def get_sigma(self, model, sampling_percent, return_actual_sigma):
+ @classmethod
+ def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput:
model_sampling = model.get_model_object("model_sampling")
sigma_val = model_sampling.percent_to_sigma(sampling_percent)
if return_actual_sigma:
@@ -327,212 +367,234 @@ class SamplingPercentToSigma:
sigma_val = model_sampling.sigma_max.item()
elif sampling_percent == 1.0:
sigma_val = model_sampling.sigma_min.item()
- return (sigma_val,)
+ return io.NodeOutput(sigma_val)
+
+ get_sigma = execute
-class KSamplerSelect:
+class KSamplerSelect(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
- }
- }
- RETURN_TYPES = ("SAMPLER",)
- CATEGORY = "sampling/custom_sampling/samplers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="KSamplerSelect",
+ category="sampling/custom_sampling/samplers",
+ inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)],
+ outputs=[io.Sampler.Output()]
+ )
- FUNCTION = "get_sampler"
-
- def get_sampler(self, sampler_name):
+ @classmethod
+ def execute(cls, sampler_name) -> io.NodeOutput:
sampler = comfy.samplers.sampler_object(sampler_name)
- return (sampler, )
+ return io.NodeOutput(sampler)
-class SamplerDPMPP_3M_SDE:
+ get_sampler = execute
+
+class SamplerDPMPP_3M_SDE(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "noise_device": (['gpu', 'cpu'], ),
- }
- }
- RETURN_TYPES = ("SAMPLER",)
- CATEGORY = "sampling/custom_sampling/samplers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerDPMPP_3M_SDE",
+ category="sampling/custom_sampling/samplers",
+ inputs=[
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Combo.Input("noise_device", options=['gpu', 'cpu']),
+ ],
+ outputs=[io.Sampler.Output()]
+ )
- FUNCTION = "get_sampler"
-
- def get_sampler(self, eta, s_noise, noise_device):
+ @classmethod
+ def execute(cls, eta, s_noise, noise_device) -> io.NodeOutput:
if noise_device == 'cpu':
sampler_name = "dpmpp_3m_sde"
else:
sampler_name = "dpmpp_3m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise})
- return (sampler, )
+ return io.NodeOutput(sampler)
-class SamplerDPMPP_2M_SDE:
+ get_sampler = execute
+
+class SamplerDPMPP_2M_SDE(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"solver_type": (['midpoint', 'heun'], ),
- "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "noise_device": (['gpu', 'cpu'], ),
- }
- }
- RETURN_TYPES = ("SAMPLER",)
- CATEGORY = "sampling/custom_sampling/samplers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerDPMPP_2M_SDE",
+ category="sampling/custom_sampling/samplers",
+ inputs=[
+ io.Combo.Input("solver_type", options=['midpoint', 'heun']),
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Combo.Input("noise_device", options=['gpu', 'cpu']),
+ ],
+ outputs=[io.Sampler.Output()]
+ )
- FUNCTION = "get_sampler"
-
- def get_sampler(self, solver_type, eta, s_noise, noise_device):
+ @classmethod
+ def execute(cls, solver_type, eta, s_noise, noise_device) -> io.NodeOutput:
if noise_device == 'cpu':
sampler_name = "dpmpp_2m_sde"
else:
sampler_name = "dpmpp_2m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
- return (sampler, )
+ return io.NodeOutput(sampler)
+
+ get_sampler = execute
-class SamplerDPMPP_SDE:
+class SamplerDPMPP_SDE(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "noise_device": (['gpu', 'cpu'], ),
- }
- }
- RETURN_TYPES = ("SAMPLER",)
- CATEGORY = "sampling/custom_sampling/samplers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerDPMPP_SDE",
+ category="sampling/custom_sampling/samplers",
+ inputs=[
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False),
+ io.Combo.Input("noise_device", options=['gpu', 'cpu']),
+ ],
+ outputs=[io.Sampler.Output()]
+ )
- FUNCTION = "get_sampler"
-
- def get_sampler(self, eta, s_noise, r, noise_device):
+ @classmethod
+ def execute(cls, eta, s_noise, r, noise_device) -> io.NodeOutput:
if noise_device == 'cpu':
sampler_name = "dpmpp_sde"
else:
sampler_name = "dpmpp_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
- return (sampler, )
+ return io.NodeOutput(sampler)
-class SamplerDPMPP_2S_Ancestral:
+ get_sampler = execute
+
+class SamplerDPMPP_2S_Ancestral(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- }
- }
- RETURN_TYPES = ("SAMPLER",)
- CATEGORY = "sampling/custom_sampling/samplers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerDPMPP_2S_Ancestral",
+ category="sampling/custom_sampling/samplers",
+ inputs=[
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ ],
+ outputs=[io.Sampler.Output()]
+ )
- FUNCTION = "get_sampler"
-
- def get_sampler(self, eta, s_noise):
+ @classmethod
+ def execute(cls, eta, s_noise) -> io.NodeOutput:
sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise})
- return (sampler, )
+ return io.NodeOutput(sampler)
-class SamplerEulerAncestral:
+ get_sampler = execute
+
+class SamplerEulerAncestral(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- }
- }
- RETURN_TYPES = ("SAMPLER",)
- CATEGORY = "sampling/custom_sampling/samplers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerEulerAncestral",
+ category="sampling/custom_sampling/samplers",
+ inputs=[
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ ],
+ outputs=[io.Sampler.Output()]
+ )
- FUNCTION = "get_sampler"
-
- def get_sampler(self, eta, s_noise):
+ @classmethod
+ def execute(cls, eta, s_noise) -> io.NodeOutput:
sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
- return (sampler, )
+ return io.NodeOutput(sampler)
-class SamplerEulerAncestralCFGPP:
+ get_sampler = execute
+
+class SamplerEulerAncestralCFGPP(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01, "round": False}),
- "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step":0.01, "round": False}),
- }}
- RETURN_TYPES = ("SAMPLER",)
- CATEGORY = "sampling/custom_sampling/samplers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerEulerAncestralCFGPP",
+ display_name="SamplerEulerAncestralCFG++",
+ category="sampling/custom_sampling/samplers",
+ inputs=[
+ io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False),
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
+ ],
+ outputs=[io.Sampler.Output()]
+ )
- FUNCTION = "get_sampler"
-
- def get_sampler(self, eta, s_noise):
+ @classmethod
+ def execute(cls, eta, s_noise) -> io.NodeOutput:
sampler = comfy.samplers.ksampler(
"euler_ancestral_cfg_pp",
{"eta": eta, "s_noise": s_noise})
- return (sampler, )
+ return io.NodeOutput(sampler)
-class SamplerLMS:
+ get_sampler = execute
+
+class SamplerLMS(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"order": ("INT", {"default": 4, "min": 1, "max": 100}),
- }
- }
- RETURN_TYPES = ("SAMPLER",)
- CATEGORY = "sampling/custom_sampling/samplers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerLMS",
+ category="sampling/custom_sampling/samplers",
+ inputs=[io.Int.Input("order", default=4, min=1, max=100)],
+ outputs=[io.Sampler.Output()]
+ )
- FUNCTION = "get_sampler"
-
- def get_sampler(self, order):
+ @classmethod
+ def execute(cls, order) -> io.NodeOutput:
sampler = comfy.samplers.ksampler("lms", {"order": order})
- return (sampler, )
+ return io.NodeOutput(sampler)
-class SamplerDPMAdaptative:
+ get_sampler = execute
+
+class SamplerDPMAdaptative(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"order": ("INT", {"default": 3, "min": 2, "max": 3}),
- "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
- }
- }
- RETURN_TYPES = ("SAMPLER",)
- CATEGORY = "sampling/custom_sampling/samplers"
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerDPMAdaptative",
+ category="sampling/custom_sampling/samplers",
+ inputs=[
+ io.Int.Input("order", default=3, min=2, max=3),
+ io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ ],
+ outputs=[io.Sampler.Output()]
+ )
- FUNCTION = "get_sampler"
-
- def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise):
+ @classmethod
+ def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) -> io.NodeOutput:
sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff,
"icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta,
"s_noise":s_noise })
- return (sampler, )
+ return io.NodeOutput(sampler)
+
+ get_sampler = execute
-class SamplerER_SDE(ComfyNodeABC):
+class SamplerER_SDE(io.ComfyNode):
@classmethod
- def INPUT_TYPES(cls) -> InputTypeDict:
- return {
- "required": {
- "solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}),
- "max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}),
- "eta": (
- IO.FLOAT,
- {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."},
- ),
- "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
- }
- }
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerER_SDE",
+ category="sampling/custom_sampling/samplers",
+ inputs=[
+ io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]),
+ io.Int.Input("max_stage", default=3, min=1, max=3),
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."),
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ ],
+ outputs=[io.Sampler.Output()]
+ )
- RETURN_TYPES = (IO.SAMPLER,)
- CATEGORY = "sampling/custom_sampling/samplers"
-
- FUNCTION = "get_sampler"
-
- def get_sampler(self, solver_type, max_stage, eta, s_noise):
+ @classmethod
+ def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput:
if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
eta = 0
s_noise = 0
@@ -548,32 +610,33 @@ class SamplerER_SDE(ComfyNodeABC):
sampler_name = "er_sde"
sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
- return (sampler,)
+ return io.NodeOutput(sampler)
+
+ get_sampler = execute
-class SamplerSASolver(ComfyNodeABC):
+class SamplerSASolver(io.ComfyNode):
@classmethod
- def INPUT_TYPES(cls) -> InputTypeDict:
- return {
- "required": {
- "model": (IO.MODEL, {}),
- "eta": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": False},),
- "sde_start_percent": (IO.FLOAT, {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001},),
- "sde_end_percent": (IO.FLOAT, {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001},),
- "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False},),
- "predictor_order": (IO.INT, {"default": 3, "min": 1, "max": 6}),
- "corrector_order": (IO.INT, {"default": 4, "min": 0, "max": 6}),
- "use_pece": (IO.BOOLEAN, {}),
- "simple_order_2": (IO.BOOLEAN, {}),
- }
- }
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerSASolver",
+ category="sampling/custom_sampling/samplers",
+ inputs=[
+ io.Model.Input("model"),
+ io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
+ io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001),
+ io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001),
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
+ io.Int.Input("predictor_order", default=3, min=1, max=6),
+ io.Int.Input("corrector_order", default=4, min=0, max=6),
+ io.Boolean.Input("use_pece"),
+ io.Boolean.Input("simple_order_2"),
+ ],
+ outputs=[io.Sampler.Output()]
+ )
- RETURN_TYPES = (IO.SAMPLER,)
- CATEGORY = "sampling/custom_sampling/samplers"
-
- FUNCTION = "get_sampler"
-
- def get_sampler(self, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2):
+ @classmethod
+ def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput:
model_sampling = model.get_model_object("model_sampling")
start_sigma = model_sampling.percent_to_sigma(sde_start_percent)
end_sigma = model_sampling.percent_to_sigma(sde_end_percent)
@@ -591,7 +654,9 @@ class SamplerSASolver(ComfyNodeABC):
"simple_order_2": simple_order_2,
},
)
- return (sampler,)
+ return io.NodeOutput(sampler)
+
+ get_sampler = execute
class Noise_EmptyNoise:
@@ -612,30 +677,31 @@ class Noise_RandomNoise:
batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
-class SamplerCustom:
+class SamplerCustom(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"model": ("MODEL",),
- "add_noise": ("BOOLEAN", {"default": True}),
- "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
- "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
- "positive": ("CONDITIONING", ),
- "negative": ("CONDITIONING", ),
- "sampler": ("SAMPLER", ),
- "sigmas": ("SIGMAS", ),
- "latent_image": ("LATENT", ),
- }
- }
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerCustom",
+ category="sampling/custom_sampling",
+ inputs=[
+ io.Model.Input("model"),
+ io.Boolean.Input("add_noise", default=True),
+ io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
+ io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
+ io.Conditioning.Input("positive"),
+ io.Conditioning.Input("negative"),
+ io.Sampler.Input("sampler"),
+ io.Sigmas.Input("sigmas"),
+ io.Latent.Input("latent_image"),
+ ],
+ outputs=[
+ io.Latent.Output(display_name="output"),
+ io.Latent.Output(display_name="denoised_output"),
+ ]
+ )
- RETURN_TYPES = ("LATENT","LATENT")
- RETURN_NAMES = ("output", "denoised_output")
-
- FUNCTION = "sample"
-
- CATEGORY = "sampling/custom_sampling"
-
- def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
+ @classmethod
+ def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput:
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
@@ -664,52 +730,58 @@ class SamplerCustom:
out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
else:
out_denoised = out
- return (out, out_denoised)
+ return io.NodeOutput(out, out_denoised)
+
+ sample = execute
class Guider_Basic(comfy.samplers.CFGGuider):
def set_conds(self, positive):
self.inner_set_conds({"positive": positive})
-class BasicGuider:
+class BasicGuider(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"model": ("MODEL",),
- "conditioning": ("CONDITIONING", ),
- }
- }
+ def define_schema(cls):
+ return io.Schema(
+ node_id="BasicGuider",
+ category="sampling/custom_sampling/guiders",
+ inputs=[
+ io.Model.Input("model"),
+ io.Conditioning.Input("conditioning"),
+ ],
+ outputs=[io.Guider.Output()]
+ )
- RETURN_TYPES = ("GUIDER",)
-
- FUNCTION = "get_guider"
- CATEGORY = "sampling/custom_sampling/guiders"
-
- def get_guider(self, model, conditioning):
+ @classmethod
+ def execute(cls, model, conditioning) -> io.NodeOutput:
guider = Guider_Basic(model)
guider.set_conds(conditioning)
- return (guider,)
+ return io.NodeOutput(guider)
-class CFGGuider:
+ get_guider = execute
+
+class CFGGuider(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"model": ("MODEL",),
- "positive": ("CONDITIONING", ),
- "negative": ("CONDITIONING", ),
- "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
- }
- }
+ def define_schema(cls):
+ return io.Schema(
+ node_id="CFGGuider",
+ category="sampling/custom_sampling/guiders",
+ inputs=[
+ io.Model.Input("model"),
+ io.Conditioning.Input("positive"),
+ io.Conditioning.Input("negative"),
+ io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
+ ],
+ outputs=[io.Guider.Output()]
+ )
- RETURN_TYPES = ("GUIDER",)
-
- FUNCTION = "get_guider"
- CATEGORY = "sampling/custom_sampling/guiders"
-
- def get_guider(self, model, positive, negative, cfg):
+ @classmethod
+ def execute(cls, model, positive, negative, cfg) -> io.NodeOutput:
guider = comfy.samplers.CFGGuider(model)
guider.set_conds(positive, negative)
guider.set_cfg(cfg)
- return (guider,)
+ return io.NodeOutput(guider)
+
+ get_guider = execute
class Guider_DualCFG(comfy.samplers.CFGGuider):
def set_cfg(self, cfg1, cfg2, nested=False):
@@ -740,84 +812,88 @@ class Guider_DualCFG(comfy.samplers.CFGGuider):
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
-class DualCFGGuider:
+class DualCFGGuider(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"model": ("MODEL",),
- "cond1": ("CONDITIONING", ),
- "cond2": ("CONDITIONING", ),
- "negative": ("CONDITIONING", ),
- "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
- "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
- "style": (["regular", "nested"],),
- }
- }
+ def define_schema(cls):
+ return io.Schema(
+ node_id="DualCFGGuider",
+ category="sampling/custom_sampling/guiders",
+ inputs=[
+ io.Model.Input("model"),
+ io.Conditioning.Input("cond1"),
+ io.Conditioning.Input("cond2"),
+ io.Conditioning.Input("negative"),
+ io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
+ io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
+ io.Combo.Input("style", options=["regular", "nested"]),
+ ],
+ outputs=[io.Guider.Output()]
+ )
- RETURN_TYPES = ("GUIDER",)
-
- FUNCTION = "get_guider"
- CATEGORY = "sampling/custom_sampling/guiders"
-
- def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style):
+ @classmethod
+ def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput:
guider = Guider_DualCFG(model)
guider.set_conds(cond1, cond2, negative)
guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))
- return (guider,)
+ return io.NodeOutput(guider)
-class DisableNoise:
+ get_guider = execute
+
+class DisableNoise(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":{
- }
- }
+ def define_schema(cls):
+ return io.Schema(
+ node_id="DisableNoise",
+ category="sampling/custom_sampling/noise",
+ inputs=[],
+ outputs=[io.Noise.Output()]
+ )
- RETURN_TYPES = ("NOISE",)
- FUNCTION = "get_noise"
- CATEGORY = "sampling/custom_sampling/noise"
-
- def get_noise(self):
- return (Noise_EmptyNoise(),)
-
-
-class RandomNoise(DisableNoise):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "noise_seed": ("INT", {
- "default": 0,
- "min": 0,
- "max": 0xffffffffffffffff,
- "control_after_generate": True,
- }),
- }
- }
+ def execute(cls) -> io.NodeOutput:
+ return io.NodeOutput(Noise_EmptyNoise())
- def get_noise(self, noise_seed):
- return (Noise_RandomNoise(noise_seed),)
+ get_noise = execute
-class SamplerCustomAdvanced:
+class RandomNoise(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"noise": ("NOISE", ),
- "guider": ("GUIDER", ),
- "sampler": ("SAMPLER", ),
- "sigmas": ("SIGMAS", ),
- "latent_image": ("LATENT", ),
- }
- }
+ def define_schema(cls):
+ return io.Schema(
+ node_id="RandomNoise",
+ category="sampling/custom_sampling/noise",
+ inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)],
+ outputs=[io.Noise.Output()]
+ )
- RETURN_TYPES = ("LATENT","LATENT")
- RETURN_NAMES = ("output", "denoised_output")
+ @classmethod
+ def execute(cls, noise_seed) -> io.NodeOutput:
+ return io.NodeOutput(Noise_RandomNoise(noise_seed))
- FUNCTION = "sample"
+ get_noise = execute
- CATEGORY = "sampling/custom_sampling"
- def sample(self, noise, guider, sampler, sigmas, latent_image):
+class SamplerCustomAdvanced(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SamplerCustomAdvanced",
+ category="sampling/custom_sampling",
+ inputs=[
+ io.Noise.Input("noise"),
+ io.Guider.Input("guider"),
+ io.Sampler.Input("sampler"),
+ io.Sigmas.Input("sigmas"),
+ io.Latent.Input("latent_image"),
+ ],
+ outputs=[
+ io.Latent.Output(display_name="output"),
+ io.Latent.Output(display_name="denoised_output"),
+ ]
+ )
+
+ @classmethod
+ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput:
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
@@ -842,28 +918,32 @@ class SamplerCustomAdvanced:
out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
else:
out_denoised = out
- return (out, out_denoised)
+ return io.NodeOutput(out, out_denoised)
-class AddNoise:
+ sample = execute
+
+class AddNoise(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"model": ("MODEL",),
- "noise": ("NOISE", ),
- "sigmas": ("SIGMAS", ),
- "latent_image": ("LATENT", ),
- }
- }
+ def define_schema(cls):
+ return io.Schema(
+ node_id="AddNoise",
+ category="_for_testing/custom_sampling/noise",
+ is_experimental=True,
+ inputs=[
+ io.Model.Input("model"),
+ io.Noise.Input("noise"),
+ io.Sigmas.Input("sigmas"),
+ io.Latent.Input("latent_image"),
+ ],
+ outputs=[
+ io.Latent.Output(),
+ ]
+ )
- RETURN_TYPES = ("LATENT",)
-
- FUNCTION = "add_noise"
-
- CATEGORY = "_for_testing/custom_sampling/noise"
-
- def add_noise(self, model, noise, sigmas, latent_image):
+ @classmethod
+ def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput:
if len(sigmas) == 0:
- return latent_image
+ return io.NodeOutput(latent_image)
latent = latent_image
latent_image = latent["samples"]
@@ -887,46 +967,50 @@ class AddNoise:
out = latent.copy()
out["samples"] = noisy
- return (out,)
+ return io.NodeOutput(out)
+
+ add_noise = execute
-NODE_CLASS_MAPPINGS = {
- "SamplerCustom": SamplerCustom,
- "BasicScheduler": BasicScheduler,
- "KarrasScheduler": KarrasScheduler,
- "ExponentialScheduler": ExponentialScheduler,
- "PolyexponentialScheduler": PolyexponentialScheduler,
- "LaplaceScheduler": LaplaceScheduler,
- "VPScheduler": VPScheduler,
- "BetaSamplingScheduler": BetaSamplingScheduler,
- "SDTurboScheduler": SDTurboScheduler,
- "KSamplerSelect": KSamplerSelect,
- "SamplerEulerAncestral": SamplerEulerAncestral,
- "SamplerEulerAncestralCFGPP": SamplerEulerAncestralCFGPP,
- "SamplerLMS": SamplerLMS,
- "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
- "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
- "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
- "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
- "SamplerDPMAdaptative": SamplerDPMAdaptative,
- "SamplerER_SDE": SamplerER_SDE,
- "SamplerSASolver": SamplerSASolver,
- "SplitSigmas": SplitSigmas,
- "SplitSigmasDenoise": SplitSigmasDenoise,
- "FlipSigmas": FlipSigmas,
- "SetFirstSigma": SetFirstSigma,
- "ExtendIntermediateSigmas": ExtendIntermediateSigmas,
- "SamplingPercentToSigma": SamplingPercentToSigma,
+class CustomSamplersExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ SamplerCustom,
+ BasicScheduler,
+ KarrasScheduler,
+ ExponentialScheduler,
+ PolyexponentialScheduler,
+ LaplaceScheduler,
+ VPScheduler,
+ BetaSamplingScheduler,
+ SDTurboScheduler,
+ KSamplerSelect,
+ SamplerEulerAncestral,
+ SamplerEulerAncestralCFGPP,
+ SamplerLMS,
+ SamplerDPMPP_3M_SDE,
+ SamplerDPMPP_2M_SDE,
+ SamplerDPMPP_SDE,
+ SamplerDPMPP_2S_Ancestral,
+ SamplerDPMAdaptative,
+ SamplerER_SDE,
+ SamplerSASolver,
+ SplitSigmas,
+ SplitSigmasDenoise,
+ FlipSigmas,
+ SetFirstSigma,
+ ExtendIntermediateSigmas,
+ SamplingPercentToSigma,
+ CFGGuider,
+ DualCFGGuider,
+ BasicGuider,
+ RandomNoise,
+ DisableNoise,
+ AddNoise,
+ SamplerCustomAdvanced,
+ ]
- "CFGGuider": CFGGuider,
- "DualCFGGuider": DualCFGGuider,
- "BasicGuider": BasicGuider,
- "RandomNoise": RandomNoise,
- "DisableNoise": DisableNoise,
- "AddNoise": AddNoise,
- "SamplerCustomAdvanced": SamplerCustomAdvanced,
-}
-NODE_DISPLAY_NAME_MAPPINGS = {
- "SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++",
-}
+async def comfy_entrypoint() -> CustomSamplersExtension:
+ return CustomSamplersExtension()
diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py
new file mode 100644
index 000000000..4789d7d53
--- /dev/null
+++ b/comfy_extras/nodes_dataset.py
@@ -0,0 +1,1432 @@
+import logging
+import os
+import json
+
+import numpy as np
+import torch
+from PIL import Image
+from typing_extensions import override
+
+import folder_paths
+import node_helpers
+from comfy_api.latest import ComfyExtension, io
+
+
+def load_and_process_images(image_files, input_dir):
+ """Utility function to load and process a list of images.
+
+ Args:
+ image_files: List of image filenames
+ input_dir: Base directory containing the images
+
+    Returns:
+        list[torch.Tensor]: List of processed image tensors, one [1, H, W, C] tensor per image
+ """
+ if not image_files:
+ raise ValueError("No valid images found in input")
+
+ output_images = []
+
+ for file in image_files:
+ image_path = os.path.join(input_dir, file)
+ img = node_helpers.pillow(Image.open, image_path)
+
+ if img.mode == "I":
+ img = img.point(lambda i: i * (1 / 255))
+ img = img.convert("RGB")
+ img_array = np.array(img).astype(np.float32) / 255.0
+ img_tensor = torch.from_numpy(img_array)[None,]
+ output_images.append(img_tensor)
+
+ return output_images
+
+
+class LoadImageDataSetFromFolderNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="LoadImageDataSetFromFolder",
+ display_name="Load Image Dataset from Folder",
+ category="dataset",
+ is_experimental=True,
+ inputs=[
+ io.Combo.Input(
+ "folder",
+ options=folder_paths.get_input_subfolders(),
+ tooltip="The folder to load images from.",
+ )
+ ],
+ outputs=[
+ io.Image.Output(
+ display_name="images",
+ is_output_list=True,
+ tooltip="List of loaded images",
+ )
+ ],
+ )
+
+ @classmethod
+ def execute(cls, folder):
+ sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
+ valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
+ image_files = [
+ f
+ for f in os.listdir(sub_input_dir)
+ if any(f.lower().endswith(ext) for ext in valid_extensions)
+ ]
+ output_tensor = load_and_process_images(image_files, sub_input_dir)
+ return io.NodeOutput(output_tensor)
+
+
+class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="LoadImageTextDataSetFromFolder",
+ display_name="Load Image and Text Dataset from Folder",
+ category="dataset",
+ is_experimental=True,
+ inputs=[
+ io.Combo.Input(
+ "folder",
+ options=folder_paths.get_input_subfolders(),
+ tooltip="The folder to load images from.",
+ )
+ ],
+ outputs=[
+ io.Image.Output(
+ display_name="images",
+ is_output_list=True,
+ tooltip="List of loaded images",
+ ),
+ io.String.Output(
+ display_name="texts",
+ is_output_list=True,
+ tooltip="List of text captions",
+ ),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, folder):
+ logging.info(f"Loading images from folder: {folder}")
+
+ sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
+ valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
+
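+        # Expected layout: images directly in the folder, each with an optional caption in a
+        # .txt file of the same name; subfolders named "<N>_<anything>" (kohya-ss style) have
+        # their images repeated N times.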
+ image_files = []
+ for item in os.listdir(sub_input_dir):
+ path = os.path.join(sub_input_dir, item)
+ if any(item.lower().endswith(ext) for ext in valid_extensions):
+ image_files.append(path)
+ elif os.path.isdir(path):
+ # Support kohya-ss/sd-scripts folder structure
+ repeat = 1
+ if item.split("_")[0].isdigit():
+ repeat = int(item.split("_")[0])
+ image_files.extend(
+ [
+ os.path.join(path, f)
+ for f in os.listdir(path)
+ if any(f.lower().endswith(ext) for ext in valid_extensions)
+ ]
+ * repeat
+ )
+
+ caption_file_path = [
+ f.replace(os.path.splitext(f)[1], ".txt") for f in image_files
+ ]
+ captions = []
+ for caption_file in caption_file_path:
+ caption_path = os.path.join(sub_input_dir, caption_file)
+ if os.path.exists(caption_path):
+ with open(caption_path, "r", encoding="utf-8") as f:
+ caption = f.read().strip()
+ captions.append(caption)
+ else:
+ captions.append("")
+
+ output_tensor = load_and_process_images(image_files, sub_input_dir)
+
+ logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
+ return io.NodeOutput(output_tensor, captions)
+
+
+def save_images_to_folder(image_list, output_dir, prefix="image"):
+ """Utility function to save a list of image tensors to disk.
+
+ Args:
+ image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W])
+ output_dir: Directory to save images to
+ prefix: Filename prefix
+
+ Returns:
+ List of saved filenames
+ """
+ os.makedirs(output_dir, exist_ok=True)
+ saved_files = []
+
+ for idx, img_tensor in enumerate(image_list):
+ # Handle different tensor shapes
+ if isinstance(img_tensor, torch.Tensor):
+ # Remove batch dimension if present [1, H, W, C] -> [H, W, C]
+ if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
+ img_tensor = img_tensor.squeeze(0)
+
+ # If tensor is [C, H, W], permute to [H, W, C]
+ if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]:
+ if (
+ img_tensor.shape[0] <= 4
+ and img_tensor.shape[1] > 4
+ and img_tensor.shape[2] > 4
+ ):
+ img_tensor = img_tensor.permute(1, 2, 0)
+
+ # Convert to numpy and scale to 0-255
+ img_array = img_tensor.cpu().numpy()
+ img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8)
+
+ # Convert to PIL Image
+ img = Image.fromarray(img_array)
+ else:
+ raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}")
+
+ # Save image
+ filename = f"{prefix}_{idx:05d}.png"
+ filepath = os.path.join(output_dir, filename)
+ img.save(filepath)
+ saved_files.append(filename)
+
+ return saved_files
+
+
+class SaveImageDataSetToFolderNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SaveImageDataSetToFolder",
+ display_name="Save Image Dataset to Folder",
+ category="dataset",
+ is_experimental=True,
+ is_output_node=True,
+ is_input_list=True, # Receive images as list
+ inputs=[
+ io.Image.Input("images", tooltip="List of images to save."),
+ io.String.Input(
+ "folder_name",
+ default="dataset",
+ tooltip="Name of the folder to save images to (inside output directory).",
+ ),
+ io.String.Input(
+ "filename_prefix",
+ default="image",
+ tooltip="Prefix for saved image filenames.",
+ ),
+ ],
+ outputs=[],
+ )
+
+ @classmethod
+ def execute(cls, images, folder_name, filename_prefix):
+ # Extract scalar values
+ folder_name = folder_name[0]
+ filename_prefix = filename_prefix[0]
+
+ output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
+ saved_files = save_images_to_folder(images, output_dir, filename_prefix)
+
+ logging.info(f"Saved {len(saved_files)} images to {output_dir}.")
+ return io.NodeOutput()
+
+
+class SaveImageTextDataSetToFolderNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SaveImageTextDataSetToFolder",
+ display_name="Save Image and Text Dataset to Folder",
+ category="dataset",
+ is_experimental=True,
+ is_output_node=True,
+ is_input_list=True, # Receive both images and texts as lists
+ inputs=[
+ io.Image.Input("images", tooltip="List of images to save."),
+ io.String.Input("texts", tooltip="List of text captions to save."),
+ io.String.Input(
+ "folder_name",
+ default="dataset",
+ tooltip="Name of the folder to save images to (inside output directory).",
+ ),
+ io.String.Input(
+ "filename_prefix",
+ default="image",
+ tooltip="Prefix for saved image filenames.",
+ ),
+ ],
+ outputs=[],
+ )
+
+ @classmethod
+ def execute(cls, images, texts, folder_name, filename_prefix):
+ # Extract scalar values
+ folder_name = folder_name[0]
+ filename_prefix = filename_prefix[0]
+
+ output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
+ saved_files = save_images_to_folder(images, output_dir, filename_prefix)
+
+ # Save captions
+ for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
+ caption_filename = filename.replace(".png", ".txt")
+ caption_path = os.path.join(output_dir, caption_filename)
+ with open(caption_path, "w", encoding="utf-8") as f:
+ f.write(caption)
+
+ logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
+ return io.NodeOutput()
+
+
+# ========== Helper Functions for Transform Nodes ==========
+
+
+def tensor_to_pil(img_tensor):
+ """Convert tensor to PIL Image."""
+ if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
+ img_tensor = img_tensor.squeeze(0)
+ img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
+ return Image.fromarray(img_array)
+
+
+def pil_to_tensor(img):
+ """Convert PIL Image to tensor."""
+ img_array = np.array(img).astype(np.float32) / 255.0
+ return torch.from_numpy(img_array)[None,]
+
+
+# ========== Base Classes for Transform Nodes ==========
+
+
+class ImageProcessingNode(io.ComfyNode):
+ """Base class for image processing nodes that operate on images.
+
+ Child classes should set:
+ node_id: Unique node identifier (required)
+ display_name: Display name (optional, defaults to node_id)
+ description: Node description (optional)
+ extra_inputs: List of additional io.Input objects beyond "images" (optional)
+ is_group_process: None (auto-detect), True (group), or False (individual) (optional)
+        is_output_list: True (list output) or False (single output) (optional; defaults to auto-detect: False for individual, True for group processing)
+
+ Child classes must implement ONE of:
+ _process(cls, image, **kwargs) -> tensor (for single-item processing)
+ _group_process(cls, images, **kwargs) -> list[tensor] (for group processing)
+ """
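+    # Illustrative sketch (not a node registered by this extension; the names below are
+    # hypothetical): a per-image subclass only needs to set node_id and override _process:
+    #
+    #     class InvertImagesNode(ImageProcessingNode):
+    #         node_id = "InvertImages"
+    #
+    #         @classmethod
+    #         def _process(cls, image, **kwargs):
+    #             # image is a single [1, H, W, C] float tensor in [0, 1]
+    #             return (1.0 - image).clamp(0.0, 1.0)
+    #
+    # A group subclass would instead set is_group_process = True and override
+    # _group_process(cls, images, **kwargs) returning a list of tensors.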
+
+ node_id = None
+ display_name = None
+ description = None
+ extra_inputs = []
+ is_group_process = None # None = auto-detect, True/False = explicit
+ is_output_list = None # None = auto-detect based on processing mode
+
+ @classmethod
+ def _detect_processing_mode(cls):
+ """Detect whether this node uses group or individual processing.
+
+ Returns:
+ bool: True if group processing, False if individual processing
+ """
+ # Explicit setting takes precedence
+ if cls.is_group_process is not None:
+ return cls.is_group_process
+
+ # Check which method is overridden by looking at the defining class in MRO
+ base_class = ImageProcessingNode
+
+ # Find which class in MRO defines _process
+ process_definer = None
+ for klass in cls.__mro__:
+ if "_process" in klass.__dict__:
+ process_definer = klass
+ break
+
+ # Find which class in MRO defines _group_process
+ group_definer = None
+ for klass in cls.__mro__:
+ if "_group_process" in klass.__dict__:
+ group_definer = klass
+ break
+
+ # Check what was overridden (not defined in base class)
+ has_process = process_definer is not None and process_definer is not base_class
+ has_group = group_definer is not None and group_definer is not base_class
+
+ if has_process and has_group:
+ raise ValueError(
+ f"{cls.__name__}: Cannot override both _process and _group_process. "
+ "Override only one, or set is_group_process explicitly."
+ )
+ if not has_process and not has_group:
+ raise ValueError(
+ f"{cls.__name__}: Must override either _process or _group_process"
+ )
+
+ return has_group
+
+ @classmethod
+ def define_schema(cls):
+ if cls.node_id is None:
+ raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
+
+ is_group = cls._detect_processing_mode()
+
+ # Auto-detect is_output_list if not explicitly set
+ # Single processing: False (backend collects results into list)
+ # Group processing: True by default (can be False for single-output nodes)
+ output_is_list = (
+ cls.is_output_list if cls.is_output_list is not None else is_group
+ )
+
+ inputs = [
+ io.Image.Input(
+ "images",
+ tooltip=(
+ "List of images to process." if is_group else "Image to process."
+ ),
+ )
+ ]
+ inputs.extend(cls.extra_inputs)
+
+ return io.Schema(
+ node_id=cls.node_id,
+ display_name=cls.display_name or cls.node_id,
+ category="dataset/image",
+ is_experimental=True,
+ is_input_list=is_group, # True for group, False for individual
+ inputs=inputs,
+ outputs=[
+ io.Image.Output(
+ display_name="images",
+ is_output_list=output_is_list,
+ tooltip="Processed images",
+ )
+ ],
+ )
+
+ @classmethod
+ def execute(cls, images, **kwargs):
+ """Execute the node. Routes to _process or _group_process based on mode."""
+ is_group = cls._detect_processing_mode()
+
+ # Extract scalar values from lists for parameters
+ params = {}
+ for k, v in kwargs.items():
+ if isinstance(v, list) and len(v) == 1:
+ params[k] = v[0]
+ else:
+ params[k] = v
+
+ if is_group:
+ # Group processing: images is list, call _group_process
+ result = cls._group_process(images, **params)
+ else:
+ # Individual processing: images is single item, call _process
+ result = cls._process(images, **params)
+
+ return io.NodeOutput(result)
+
+ @classmethod
+ def _process(cls, image, **kwargs):
+ """Override this method for single-item processing.
+
+ Args:
+ image: tensor - Single image tensor
+ **kwargs: Additional parameters (already extracted from lists)
+
+ Returns:
+ tensor - Processed image
+ """
+ raise NotImplementedError(f"{cls.__name__} must implement _process method")
+
+ @classmethod
+ def _group_process(cls, images, **kwargs):
+ """Override this method for group processing.
+
+ Args:
+ images: list[tensor] - List of image tensors
+ **kwargs: Additional parameters (already extracted from lists)
+
+ Returns:
+ list[tensor] - Processed images
+ """
+ raise NotImplementedError(
+ f"{cls.__name__} must implement _group_process method"
+ )
+
+
+class TextProcessingNode(io.ComfyNode):
+ """Base class for text processing nodes that operate on texts.
+
+ Child classes should set:
+ node_id: Unique node identifier (required)
+ display_name: Display name (optional, defaults to node_id)
+ description: Node description (optional)
+ extra_inputs: List of additional io.Input objects beyond "texts" (optional)
+ is_group_process: None (auto-detect), True (group), or False (individual) (optional)
+        is_output_list: True (list output) or False (single output) (optional; defaults to auto-detect: False for individual, True for group processing)
+
+ Child classes must implement ONE of:
+ _process(cls, text, **kwargs) -> str (for single-item processing)
+ _group_process(cls, texts, **kwargs) -> list[str] (for group processing)
+ """
+
+ node_id = None
+ display_name = None
+ description = None
+ extra_inputs = []
+ is_group_process = None # None = auto-detect, True/False = explicit
+ is_output_list = None # None = auto-detect based on processing mode
+
+ @classmethod
+ def _detect_processing_mode(cls):
+ """Detect whether this node uses group or individual processing.
+
+ Returns:
+ bool: True if group processing, False if individual processing
+ """
+ # Explicit setting takes precedence
+ if cls.is_group_process is not None:
+ return cls.is_group_process
+
+ # Check which method is overridden by looking at the defining class in MRO
+ base_class = TextProcessingNode
+
+ # Find which class in MRO defines _process
+ process_definer = None
+ for klass in cls.__mro__:
+ if "_process" in klass.__dict__:
+ process_definer = klass
+ break
+
+ # Find which class in MRO defines _group_process
+ group_definer = None
+ for klass in cls.__mro__:
+ if "_group_process" in klass.__dict__:
+ group_definer = klass
+ break
+
+ # Check what was overridden (not defined in base class)
+ has_process = process_definer is not None and process_definer is not base_class
+ has_group = group_definer is not None and group_definer is not base_class
+
+ if has_process and has_group:
+ raise ValueError(
+ f"{cls.__name__}: Cannot override both _process and _group_process. "
+ "Override only one, or set is_group_process explicitly."
+ )
+ if not has_process and not has_group:
+ raise ValueError(
+ f"{cls.__name__}: Must override either _process or _group_process"
+ )
+
+ return has_group
+
+ @classmethod
+ def define_schema(cls):
+ if cls.node_id is None:
+ raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
+
+        is_group = cls._detect_processing_mode()
+
+        # Auto-detect is_output_list if not explicitly set
+        # Single processing: False (backend collects results into list)
+        # Group processing: True (list in, list out)
+        output_is_list = (
+            cls.is_output_list if cls.is_output_list is not None else is_group
+        )
+
+ inputs = [
+ io.String.Input(
+ "texts",
+ tooltip="List of texts to process." if is_group else "Text to process.",
+ )
+ ]
+ inputs.extend(cls.extra_inputs)
+
+ return io.Schema(
+ node_id=cls.node_id,
+ display_name=cls.display_name or cls.node_id,
+ category="dataset/text",
+ is_experimental=True,
+ is_input_list=is_group, # True for group, False for individual
+ inputs=inputs,
+ outputs=[
+ io.String.Output(
+ display_name="texts",
+                    is_output_list=output_is_list,
+ tooltip="Processed texts",
+ )
+ ],
+ )
+
+ @classmethod
+ def execute(cls, texts, **kwargs):
+ """Execute the node. Routes to _process or _group_process based on mode."""
+ is_group = cls._detect_processing_mode()
+
+ # Extract scalar values from lists for parameters
+ params = {}
+ for k, v in kwargs.items():
+ if isinstance(v, list) and len(v) == 1:
+ params[k] = v[0]
+ else:
+ params[k] = v
+
+ if is_group:
+ # Group processing: texts is list, call _group_process
+ result = cls._group_process(texts, **params)
+ else:
+ # Individual processing: texts is single item, call _process
+ result = cls._process(texts, **params)
+
+        return io.NodeOutput(result)
+
+ @classmethod
+ def _process(cls, text, **kwargs):
+ """Override this method for single-item processing.
+
+ Args:
+ text: str - Single text string
+ **kwargs: Additional parameters (already extracted from lists)
+
+ Returns:
+ str - Processed text
+ """
+ raise NotImplementedError(f"{cls.__name__} must implement _process method")
+
+ @classmethod
+ def _group_process(cls, texts, **kwargs):
+ """Override this method for group processing.
+
+ Args:
+ texts: list[str] - List of text strings
+ **kwargs: Additional parameters (already extracted from lists)
+
+ Returns:
+ list[str] - Processed texts
+ """
+ raise NotImplementedError(
+ f"{cls.__name__} must implement _group_process method"
+ )
+
+
+# ========== Image Transform Nodes ==========
+
+
+class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
+ node_id = "ResizeImagesByShorterEdge"
+ display_name = "Resize Images by Shorter Edge"
+ description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
+ extra_inputs = [
+ io.Int.Input(
+ "shorter_edge",
+ default=512,
+ min=1,
+ max=8192,
+ tooltip="Target length for the shorter edge.",
+ ),
+ ]
+
+ @classmethod
+ def _process(cls, image, shorter_edge):
+ img = tensor_to_pil(image)
+ w, h = img.size
+ if w < h:
+ new_w = shorter_edge
+ new_h = int(h * (shorter_edge / w))
+ else:
+ new_h = shorter_edge
+ new_w = int(w * (shorter_edge / h))
+ img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+ return pil_to_tensor(img)
+
+
+class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
+ node_id = "ResizeImagesByLongerEdge"
+ display_name = "Resize Images by Longer Edge"
+ description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
+ extra_inputs = [
+ io.Int.Input(
+ "longer_edge",
+ default=1024,
+ min=1,
+ max=8192,
+ tooltip="Target length for the longer edge.",
+ ),
+ ]
+
+ @classmethod
+ def _process(cls, image, longer_edge):
+ img = tensor_to_pil(image)
+ w, h = img.size
+ if w > h:
+ new_w = longer_edge
+ new_h = int(h * (longer_edge / w))
+ else:
+ new_h = longer_edge
+ new_w = int(w * (longer_edge / h))
+ img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+ return pil_to_tensor(img)
+
+
+class CenterCropImagesNode(ImageProcessingNode):
+ node_id = "CenterCropImages"
+ display_name = "Center Crop Images"
+ description = "Center crop all images to the specified dimensions."
+ extra_inputs = [
+ io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
+ io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
+ ]
+
+ @classmethod
+ def _process(cls, image, width, height):
+ img = tensor_to_pil(image)
+ left = max(0, (img.width - width) // 2)
+ top = max(0, (img.height - height) // 2)
+ right = min(img.width, left + width)
+ bottom = min(img.height, top + height)
+ img = img.crop((left, top, right, bottom))
+ return pil_to_tensor(img)
+
+
+class RandomCropImagesNode(ImageProcessingNode):
+ node_id = "RandomCropImages"
+ display_name = "Random Crop Images"
+ description = (
+ "Randomly crop all images to the specified dimensions (for data augmentation)."
+ )
+ extra_inputs = [
+ io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
+ io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
+ io.Int.Input(
+ "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
+ ),
+ ]
+
+ @classmethod
+ def _process(cls, image, width, height, seed):
+ np.random.seed(seed % (2**32 - 1))
+ img = tensor_to_pil(image)
+ max_left = max(0, img.width - width)
+ max_top = max(0, img.height - height)
+ left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
+ top = np.random.randint(0, max_top + 1) if max_top > 0 else 0
+ right = min(img.width, left + width)
+ bottom = min(img.height, top + height)
+ img = img.crop((left, top, right, bottom))
+ return pil_to_tensor(img)
+
+
+class NormalizeImagesNode(ImageProcessingNode):
+ node_id = "NormalizeImages"
+ display_name = "Normalize Images"
+ description = "Normalize images using mean and standard deviation."
+ extra_inputs = [
+ io.Float.Input(
+ "mean",
+ default=0.5,
+ min=0.0,
+ max=1.0,
+ tooltip="Mean value for normalization.",
+ ),
+ io.Float.Input(
+ "std",
+ default=0.5,
+ min=0.001,
+ max=1.0,
+ tooltip="Standard deviation for normalization.",
+ ),
+ ]
+
+ @classmethod
+ def _process(cls, image, mean, std):
+ return (image - mean) / std
+
+
+class AdjustBrightnessNode(ImageProcessingNode):
+ node_id = "AdjustBrightness"
+ display_name = "Adjust Brightness"
+ description = "Adjust brightness of all images."
+ extra_inputs = [
+ io.Float.Input(
+ "factor",
+ default=1.0,
+ min=0.0,
+ max=2.0,
+ tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.",
+ ),
+ ]
+
+ @classmethod
+ def _process(cls, image, factor):
+ return (image * factor).clamp(0.0, 1.0)
+
+
+class AdjustContrastNode(ImageProcessingNode):
+ node_id = "AdjustContrast"
+ display_name = "Adjust Contrast"
+ description = "Adjust contrast of all images."
+ extra_inputs = [
+ io.Float.Input(
+ "factor",
+ default=1.0,
+ min=0.0,
+ max=2.0,
+ tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.",
+ ),
+ ]
+
+ @classmethod
+ def _process(cls, image, factor):
+ return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0)
+
+
+class ShuffleDatasetNode(ImageProcessingNode):
+ node_id = "ShuffleDataset"
+ display_name = "Shuffle Image Dataset"
+ description = "Randomly shuffle the order of images in the dataset."
+ is_group_process = True # Requires full list to shuffle
+ extra_inputs = [
+ io.Int.Input(
+ "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
+ ),
+ ]
+
+ @classmethod
+ def _group_process(cls, images, seed):
+ np.random.seed(seed % (2**32 - 1))
+ indices = np.random.permutation(len(images))
+ return [images[i] for i in indices]
+
+
+class ShuffleImageTextDatasetNode(io.ComfyNode):
+ """Special node that shuffles both images and texts together."""
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="ShuffleImageTextDataset",
+ display_name="Shuffle Image-Text Dataset",
+ category="dataset/image",
+ is_experimental=True,
+ is_input_list=True,
+ inputs=[
+ io.Image.Input("images", tooltip="List of images to shuffle."),
+ io.String.Input("texts", tooltip="List of texts to shuffle."),
+ io.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=0xFFFFFFFFFFFFFFFF,
+ tooltip="Random seed.",
+ ),
+ ],
+ outputs=[
+ io.Image.Output(
+ display_name="images",
+ is_output_list=True,
+ tooltip="Shuffled images",
+ ),
+ io.String.Output(
+ display_name="texts", is_output_list=True, tooltip="Shuffled texts"
+ ),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, images, texts, seed):
+ seed = seed[0] # Extract scalar
+ np.random.seed(seed % (2**32 - 1))
+ indices = np.random.permutation(len(images))
+ shuffled_images = [images[i] for i in indices]
+ shuffled_texts = [texts[i] for i in indices]
+ return io.NodeOutput(shuffled_images, shuffled_texts)
+
+
+# ========== Text Transform Nodes ==========
+
+
+class TextToLowercaseNode(TextProcessingNode):
+ node_id = "TextToLowercase"
+ display_name = "Text to Lowercase"
+ description = "Convert all texts to lowercase."
+
+ @classmethod
+ def _process(cls, text):
+ return text.lower()
+
+
+class TextToUppercaseNode(TextProcessingNode):
+ node_id = "TextToUppercase"
+ display_name = "Text to Uppercase"
+ description = "Convert all texts to uppercase."
+
+ @classmethod
+ def _process(cls, text):
+ return text.upper()
+
+
+class TruncateTextNode(TextProcessingNode):
+ node_id = "TruncateText"
+ display_name = "Truncate Text"
+ description = "Truncate all texts to a maximum length."
+ extra_inputs = [
+ io.Int.Input(
+ "max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
+ ),
+ ]
+
+ @classmethod
+ def _process(cls, text, max_length):
+ return text[:max_length]
+
+
+class AddTextPrefixNode(TextProcessingNode):
+ node_id = "AddTextPrefix"
+ display_name = "Add Text Prefix"
+ description = "Add a prefix to all texts."
+ extra_inputs = [
+ io.String.Input("prefix", default="", tooltip="Prefix to add."),
+ ]
+
+ @classmethod
+ def _process(cls, text, prefix):
+ return prefix + text
+
+
+class AddTextSuffixNode(TextProcessingNode):
+ node_id = "AddTextSuffix"
+ display_name = "Add Text Suffix"
+ description = "Add a suffix to all texts."
+ extra_inputs = [
+ io.String.Input("suffix", default="", tooltip="Suffix to add."),
+ ]
+
+ @classmethod
+ def _process(cls, text, suffix):
+ return text + suffix
+
+
+class ReplaceTextNode(TextProcessingNode):
+ node_id = "ReplaceText"
+ display_name = "Replace Text"
+ description = "Replace text in all texts."
+ extra_inputs = [
+ io.String.Input("find", default="", tooltip="Text to find."),
+ io.String.Input("replace", default="", tooltip="Text to replace with."),
+ ]
+
+ @classmethod
+ def _process(cls, text, find, replace):
+ return text.replace(find, replace)
+
+
+class StripWhitespaceNode(TextProcessingNode):
+ node_id = "StripWhitespace"
+ display_name = "Strip Whitespace"
+ description = "Strip leading and trailing whitespace from all texts."
+
+ @classmethod
+ def _process(cls, text):
+ return text.strip()
+
+
+# ========== Group Processing Example Nodes ==========
+
+
+class ImageDeduplicationNode(ImageProcessingNode):
+ """Remove duplicate or very similar images from the dataset using perceptual hashing."""
+
+ node_id = "ImageDeduplication"
+ display_name = "Image Deduplication"
+ description = "Remove duplicate or very similar images from the dataset."
+ is_group_process = True # Requires full list to compare images
+ extra_inputs = [
+ io.Float.Input(
+ "similarity_threshold",
+ default=0.95,
+ min=0.0,
+ max=1.0,
+ tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
+ ),
+ ]
+
+ @classmethod
+ def _group_process(cls, images, similarity_threshold):
+ """Remove duplicate images using perceptual hashing."""
+ if len(images) == 0:
+ return []
+
+ # Compute simple perceptual hash for each image
+ def compute_hash(img_tensor):
+ """Compute a simple perceptual hash by resizing to 8x8 and comparing to average."""
+ img = tensor_to_pil(img_tensor)
+ # Resize to 8x8
+ img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L")
+ # Get pixels
+ pixels = list(img_small.getdata())
+ # Compute average
+ avg = sum(pixels) / len(pixels)
+ # Create hash (1 if above average, 0 otherwise)
+ hash_bits = "".join("1" if p > avg else "0" for p in pixels)
+ return hash_bits
+
+ def hamming_distance(hash1, hash2):
+ """Compute Hamming distance between two hash strings."""
+ return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))
+
+ # Compute hashes for all images
+ hashes = [compute_hash(img) for img in images]
+
+ # Find duplicates
+ keep_indices = []
+ for i in range(len(images)):
+ is_duplicate = False
+ for j in keep_indices:
+ # Compare hashes
+ distance = hamming_distance(hashes[i], hashes[j])
+ similarity = 1.0 - (distance / 64.0) # 64 bits total
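+                # e.g. the default threshold of 0.95 treats two images as duplicates when
+                # their 64-bit hashes differ in at most 3 bits (1 - 3/64 ≈ 0.953)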
+ if similarity >= similarity_threshold:
+ is_duplicate = True
+ logging.info(
+ f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping"
+ )
+ break
+
+ if not is_duplicate:
+ keep_indices.append(i)
+
+ # Return only unique images
+ unique_images = [images[i] for i in keep_indices]
+ logging.info(
+ f"Deduplication: kept {len(unique_images)} out of {len(images)} images"
+ )
+ return unique_images
+
+
+class ImageGridNode(ImageProcessingNode):
+ """Combine multiple images into a single grid/collage."""
+
+ node_id = "ImageGrid"
+ display_name = "Image Grid"
+ description = "Arrange multiple images into a grid layout."
+ is_group_process = True # Requires full list to create grid
+ is_output_list = False # Outputs single grid image
+ extra_inputs = [
+ io.Int.Input(
+ "columns",
+ default=4,
+ min=1,
+ max=20,
+ tooltip="Number of columns in the grid.",
+ ),
+ io.Int.Input(
+ "cell_width",
+ default=256,
+ min=32,
+ max=2048,
+ tooltip="Width of each cell in the grid.",
+ ),
+ io.Int.Input(
+ "cell_height",
+ default=256,
+ min=32,
+ max=2048,
+ tooltip="Height of each cell in the grid.",
+ ),
+ io.Int.Input(
+ "padding", default=4, min=0, max=50, tooltip="Padding between images."
+ ),
+ ]
+
+ @classmethod
+ def _group_process(cls, images, columns, cell_width, cell_height, padding):
+ """Arrange images into a grid."""
+ if len(images) == 0:
+ raise ValueError("Cannot create grid from empty image list")
+
+ # Calculate grid dimensions
+ num_images = len(images)
+ rows = (num_images + columns - 1) // columns # Ceiling division
+
+ # Calculate total grid size
+ grid_width = columns * cell_width + (columns - 1) * padding
+ grid_height = rows * cell_height + (rows - 1) * padding
+
+ # Create blank grid
+ grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0))
+
+ # Place images
+ for idx, img_tensor in enumerate(images):
+ row = idx // columns
+ col = idx % columns
+
+ # Convert to PIL and resize to cell size
+ img = tensor_to_pil(img_tensor)
+ img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS)
+
+ # Calculate position
+ x = col * (cell_width + padding)
+ y = row * (cell_height + padding)
+
+ # Paste into grid
+ grid.paste(img, (x, y))
+
+ logging.info(
+ f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})"
+ )
+ return pil_to_tensor(grid)
+
+
+class MergeImageListsNode(ImageProcessingNode):
+ """Merge multiple image lists into a single list."""
+
+ node_id = "MergeImageLists"
+ display_name = "Merge Image Lists"
+ description = "Concatenate multiple image lists into one."
+ is_group_process = True # Receives images as list
+
+ @classmethod
+ def _group_process(cls, images):
+ """Simply return the images list (already merged by input handling)."""
+ # When multiple list inputs are connected, they're concatenated
+ # For now, this is a simple pass-through
+ logging.info(f"Merged image list contains {len(images)} images")
+ return images
+
+
+class MergeTextListsNode(TextProcessingNode):
+ """Merge multiple text lists into a single list."""
+
+ node_id = "MergeTextLists"
+ display_name = "Merge Text Lists"
+ description = "Concatenate multiple text lists into one."
+ is_group_process = True # Receives texts as list
+
+ @classmethod
+ def _group_process(cls, texts):
+ """Simply return the texts list (already merged by input handling)."""
+ # When multiple list inputs are connected, they're concatenated
+ # For now, this is a simple pass-through
+ logging.info(f"Merged text list contains {len(texts)} texts")
+ return texts
+
+
+# ========== Training Dataset Nodes ==========
+
+
+class MakeTrainingDataset(io.ComfyNode):
+ """Encode images with VAE and texts with CLIP to create a training dataset."""
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="MakeTrainingDataset",
+ display_name="Make Training Dataset",
+ category="dataset",
+ is_experimental=True,
+ is_input_list=True, # images and texts as lists
+ inputs=[
+ io.Image.Input("images", tooltip="List of images to encode."),
+ io.Vae.Input(
+ "vae", tooltip="VAE model for encoding images to latents."
+ ),
+ io.Clip.Input(
+ "clip", tooltip="CLIP model for encoding text to conditioning."
+ ),
+ io.String.Input(
+ "texts",
+ optional=True,
+ tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
+ ),
+ ],
+ outputs=[
+ io.Latent.Output(
+ display_name="latents",
+ is_output_list=True,
+ tooltip="List of latent dicts",
+ ),
+ io.Conditioning.Output(
+ display_name="conditioning",
+ is_output_list=True,
+ tooltip="List of conditioning lists",
+ ),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, images, vae, clip, texts=None):
+ # Extract scalars (vae and clip are single values wrapped in lists)
+ vae = vae[0]
+ clip = clip[0]
+
+ # Handle text list
+ num_images = len(images)
+
+ if texts is None or len(texts) == 0:
+ # Treat as [""] for unconditional training
+ texts = [""]
+
+ if len(texts) == 1 and num_images > 1:
+ # Repeat single text for all images
+ texts = texts * num_images
+ elif len(texts) != num_images:
+ raise ValueError(
+ f"Number of texts ({len(texts)}) does not match number of images ({num_images}). "
+ f"Text list should have length {num_images}, 1, or 0."
+ )
+
+ # Encode images with VAE
+ logging.info(f"Encoding {num_images} images with VAE...")
+ latents_list = [] # list[{"samples": tensor}]
+ for img_tensor in images:
+ # img_tensor is [1, H, W, 3]
+ latent_tensor = vae.encode(img_tensor[:, :, :, :3])
+ latents_list.append({"samples": latent_tensor})
+
+ # Encode texts with CLIP
+ logging.info(f"Encoding {len(texts)} texts with CLIP...")
+ conditioning_list = [] # list[list[cond]]
+ for text in texts:
+ if text == "":
+ cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
+ else:
+ tokens = clip.tokenize(text)
+ cond = clip.encode_from_tokens_scheduled(tokens)
+ conditioning_list.append(cond)
+
+ logging.info(
+ f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning."
+ )
+ return io.NodeOutput(latents_list, conditioning_list)
+
+
+class SaveTrainingDataset(io.ComfyNode):
+ """Save encoded training dataset (latents + conditioning) to disk."""
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SaveTrainingDataset",
+ display_name="Save Training Dataset",
+ category="dataset",
+ is_experimental=True,
+ is_output_node=True,
+ is_input_list=True, # Receive lists
+ inputs=[
+ io.Latent.Input(
+ "latents",
+ tooltip="List of latent dicts from MakeTrainingDataset.",
+ ),
+ io.Conditioning.Input(
+ "conditioning",
+ tooltip="List of conditioning lists from MakeTrainingDataset.",
+ ),
+ io.String.Input(
+ "folder_name",
+ default="training_dataset",
+ tooltip="Name of folder to save dataset (inside output directory).",
+ ),
+ io.Int.Input(
+ "shard_size",
+ default=1000,
+ min=1,
+ max=100000,
+ tooltip="Number of samples per shard file.",
+ ),
+ ],
+ outputs=[],
+ )
+
+ @classmethod
+ def execute(cls, latents, conditioning, folder_name, shard_size):
+ # Extract scalars
+ folder_name = folder_name[0]
+ shard_size = shard_size[0]
+
+ # latents: list[{"samples": tensor}]
+ # conditioning: list[list[cond]]
+
+ # Validate lengths match
+ if len(latents) != len(conditioning):
+ raise ValueError(
+ f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
+ f"Something went wrong in dataset preparation."
+ )
+
+ # Create output directory
+ output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
+ os.makedirs(output_dir, exist_ok=True)
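+        # Resulting layout: <output>/<folder_name>/shard_0000.pkl, shard_0001.pkl, ...,
+        # plus metadata.json recording num_samples, num_shards and shard_size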
+
+ # Prepare data pairs
+ num_samples = len(latents)
+ num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division
+
+ logging.info(
+ f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
+ )
+
+ # Save data in shards
+ for shard_idx in range(num_shards):
+ start_idx = shard_idx * shard_size
+ end_idx = min(start_idx + shard_size, num_samples)
+
+ # Get shard data (list of latent dicts and conditioning lists)
+ shard_data = {
+ "latents": latents[start_idx:end_idx],
+ "conditioning": conditioning[start_idx:end_idx],
+ }
+
+ # Save shard
+ shard_filename = f"shard_{shard_idx:04d}.pkl"
+ shard_path = os.path.join(output_dir, shard_filename)
+
+ with open(shard_path, "wb") as f:
+ torch.save(shard_data, f)
+
+ logging.info(
+ f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
+ )
+
+ # Save metadata
+ metadata = {
+ "num_samples": num_samples,
+ "num_shards": num_shards,
+ "shard_size": shard_size,
+ }
+ metadata_path = os.path.join(output_dir, "metadata.json")
+ with open(metadata_path, "w") as f:
+ json.dump(metadata, f, indent=2)
+
+ logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
+ return io.NodeOutput()
+
+
+class LoadTrainingDataset(io.ComfyNode):
+ """Load encoded training dataset from disk."""
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="LoadTrainingDataset",
+ display_name="Load Training Dataset",
+ category="dataset",
+ is_experimental=True,
+ inputs=[
+ io.String.Input(
+ "folder_name",
+ default="training_dataset",
+ tooltip="Name of folder containing the saved dataset (inside output directory).",
+ ),
+ ],
+ outputs=[
+ io.Latent.Output(
+ display_name="latents",
+ is_output_list=True,
+ tooltip="List of latent dicts",
+ ),
+ io.Conditioning.Output(
+ display_name="conditioning",
+ is_output_list=True,
+ tooltip="List of conditioning lists",
+ ),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, folder_name):
+ # Get dataset directory
+ dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
+
+ if not os.path.exists(dataset_dir):
+ raise ValueError(f"Dataset directory not found: {dataset_dir}")
+
+ # Find all shard files
+ shard_files = sorted(
+ [
+ f
+ for f in os.listdir(dataset_dir)
+ if f.startswith("shard_") and f.endswith(".pkl")
+ ]
+ )
+
+ if not shard_files:
+ raise ValueError(f"No shard files found in {dataset_dir}")
+
+ logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
+
+ # Load all shards
+ all_latents = [] # list[{"samples": tensor}]
+ all_conditioning = [] # list[list[cond]]
+
+ for shard_file in shard_files:
+ shard_path = os.path.join(dataset_dir, shard_file)
+
+ with open(shard_path, "rb") as f:
+ shard_data = torch.load(f, weights_only=True)
+
+ all_latents.extend(shard_data["latents"])
+ all_conditioning.extend(shard_data["conditioning"])
+
+ logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
+
+ logging.info(
+ f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
+ )
+ return io.NodeOutput(all_latents, all_conditioning)
+
+
+# ========== Extension Setup ==========
+
+
+class DatasetExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ # Data loading/saving nodes
+ LoadImageDataSetFromFolderNode,
+ LoadImageTextDataSetFromFolderNode,
+ SaveImageDataSetToFolderNode,
+ SaveImageTextDataSetToFolderNode,
+ # Image transform nodes
+ ResizeImagesByShorterEdgeNode,
+ ResizeImagesByLongerEdgeNode,
+ CenterCropImagesNode,
+ RandomCropImagesNode,
+ NormalizeImagesNode,
+ AdjustBrightnessNode,
+ AdjustContrastNode,
+ ShuffleDatasetNode,
+ ShuffleImageTextDatasetNode,
+ # Text transform nodes
+ TextToLowercaseNode,
+ TextToUppercaseNode,
+ TruncateTextNode,
+ AddTextPrefixNode,
+ AddTextSuffixNode,
+ ReplaceTextNode,
+ StripWhitespaceNode,
+ # Group processing examples
+ ImageDeduplicationNode,
+ ImageGridNode,
+ MergeImageListsNode,
+ MergeTextListsNode,
+ # Training dataset nodes
+ MakeTrainingDataset,
+ SaveTrainingDataset,
+ LoadTrainingDataset,
+ ]
+
+
+async def comfy_entrypoint() -> DatasetExtension:
+ return DatasetExtension()
diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py
index e3ac58447..3429b731e 100644
--- a/comfy_extras/nodes_freelunch.py
+++ b/comfy_extras/nodes_freelunch.py
@@ -2,6 +2,8 @@
import torch
import logging
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, IO
def Fourier_filter(x, threshold, scale):
# FFT
@@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale):
return x_filtered.to(x.dtype)
-class FreeU:
+class FreeU(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "model": ("MODEL",),
- "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
- "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
- "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
- "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
- }}
- RETURN_TYPES = ("MODEL",)
- FUNCTION = "patch"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="FreeU",
+ category="model_patches/unet",
+ inputs=[
+ IO.Model.Input("model"),
+ IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
+ ],
+ outputs=[
+ IO.Model.Output(),
+ ],
+ )
- CATEGORY = "model_patches/unet"
-
- def patch(self, model, b1, b2, s1, s2):
+ @classmethod
+ def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
@@ -59,23 +66,31 @@ class FreeU:
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
- return (m, )
+ return IO.NodeOutput(m)
-class FreeU_V2:
+ patch = execute # TODO: remove
+
+
+class FreeU_V2(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "model": ("MODEL",),
- "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
- "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
- "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
- "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
- }}
- RETURN_TYPES = ("MODEL",)
- FUNCTION = "patch"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="FreeU_V2",
+ category="model_patches/unet",
+ inputs=[
+ IO.Model.Input("model"),
+ IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
+ IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
+ ],
+ outputs=[
+ IO.Model.Output(),
+ ],
+ )
- CATEGORY = "model_patches/unet"
-
- def patch(self, model, b1, b2, s1, s2):
+ @classmethod
+ def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
@@ -105,9 +120,19 @@ class FreeU_V2:
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
- return (m, )
+ return IO.NodeOutput(m)
-NODE_CLASS_MAPPINGS = {
- "FreeU": FreeU,
- "FreeU_V2": FreeU_V2,
-}
+ patch = execute # TODO: remove
+
+
+class FreelunchExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ FreeU,
+ FreeU_V2,
+ ]
+
+
+async def comfy_entrypoint() -> FreelunchExtension:
+ return FreelunchExtension()
diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py
new file mode 100644
index 000000000..9cb234be1
--- /dev/null
+++ b/comfy_extras/nodes_kandinsky5.py
@@ -0,0 +1,136 @@
+import nodes
+import node_helpers
+import torch
+import comfy.model_management
+import comfy.utils
+
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+
+
+class Kandinsky5ImageToVideo(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="Kandinsky5ImageToVideo",
+ category="conditioning/video_models",
+ inputs=[
+ io.Conditioning.Input("positive"),
+ io.Conditioning.Input("negative"),
+ io.Vae.Input("vae"),
+ io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
+ io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
+ io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
+ io.Image.Input("start_image", optional=True),
+ ],
+ outputs=[
+ io.Conditioning.Output(display_name="positive"),
+ io.Conditioning.Output(display_name="negative"),
+ io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
+ io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
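+        # Empty video latent: 16 channels, 4x temporal compression ((length - 1) // 4 + 1
+        # latent frames) and 8x spatial compression (height // 8 by width // 8)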
+ latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+ cond_latent_out = {}
+ if start_image is not None:
+ start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+ encoded = vae.encode(start_image[:, :, :, :3])
+ cond_latent_out["samples"] = encoded
+
+ mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+ mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
+
+ positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
+ negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
+
+ out_latent = {}
+ out_latent["samples"] = latent
+ return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
+
+
+def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
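+    # Shift/scale `source` (statistics taken per batch item over C, H, W) toward the
+    # statistics of `reference`, clamping the reference mean/std to a window around the
+    # source statistics so the adjustment stays bounded.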
+ source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
+ source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
+
+ reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
+ reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
+
+ # normalization
+ normalized = (source - source_mean) / (source_std + 1e-8)
+ normalized = normalized * reference_std + reference_mean
+
+ return normalized
+
+
+class NormalizeVideoLatentStart(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="NormalizeVideoLatentStart",
+ category="conditioning/video_models",
+ description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
+ inputs=[
+ io.Latent.Input("latent"),
+ io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
+ io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
+ ],
+ outputs=[
+ io.Latent.Output(display_name="latent"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
+ if latent["samples"].shape[2] <= 1:
+ return io.NodeOutput(latent)
+ s = latent.copy()
+ samples = latent["samples"].clone()
+
+ first_frames = samples[:, :, :start_frame_count]
+ reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
+ normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
+
+ samples[:, :, :start_frame_count] = normalized_first_frames
+ s["samples"] = samples
+ return io.NodeOutput(s)
+
+
+class CLIPTextEncodeKandinsky5(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="CLIPTextEncodeKandinsky5",
+ category="advanced/conditioning/kandinsky5",
+ inputs=[
+ io.Clip.Input("clip"),
+ io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
+ io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
+ ],
+ outputs=[
+ io.Conditioning.Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
+ tokens = clip.tokenize(clip_l)
+ tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
+
+ return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+
+
+class Kandinsky5Extension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ Kandinsky5ImageToVideo,
+ NormalizeVideoLatentStart,
+ CLIPTextEncodeKandinsky5,
+ ]
+
+async def comfy_entrypoint() -> Kandinsky5Extension:
+ return Kandinsky5Extension()
diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py
index d2df07ff9..e439b18ef 100644
--- a/comfy_extras/nodes_latent.py
+++ b/comfy_extras/nodes_latent.py
@@ -4,7 +4,7 @@ import torch
import nodes
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
-
+import logging
def reshape_latent_to(target_shape, latent, repeat_batch=True):
if latent.shape[1:] != target_shape[1:]:
@@ -388,6 +388,42 @@ class LatentOperationSharpen(io.ComfyNode):
return luminance * sharpened
return io.NodeOutput(sharpen)
+class ReplaceVideoLatentFrames(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="ReplaceVideoLatentFrames",
+ category="latent/batch",
+ inputs=[
+ io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
+ io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
+ io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
+ ],
+ outputs=[
+ io.Latent.Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, destination, index, source=None) -> io.NodeOutput:
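+        # If the source frames do not fit at the requested index, the
+        # destination is returned unchanged with a warning rather than
+        # partially overwritten.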
+ if source is None:
+ return io.NodeOutput(destination)
+ dest_frames = destination["samples"].shape[2]
+ source_frames = source["samples"].shape[2]
+ if index < 0:
+ index = dest_frames + index
+ if index > dest_frames:
+ logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
+ return io.NodeOutput(destination)
+ if index + source_frames > dest_frames:
+ logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
+ return io.NodeOutput(destination)
+ s = source.copy()
+ s_source = source["samples"]
+ s_destination = destination["samples"].clone()
+ s_destination[:, :, index:index + s_source.shape[2]] = s_source
+ s["samples"] = s_destination
+ return io.NodeOutput(s)
class LatentExtension(ComfyExtension):
@override
@@ -405,6 +441,7 @@ class LatentExtension(ComfyExtension):
LatentApplyOperationCFG,
LatentOperationTonemapReinhard,
LatentOperationSharpen,
+ ReplaceVideoLatentFrames
]
diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py
index 899608149..545588ef8 100644
--- a/comfy_extras/nodes_load_3d.py
+++ b/comfy_extras/nodes_load_3d.py
@@ -2,8 +2,8 @@ import nodes
import folder_paths
import os
-from comfy.comfy_types import IO
-from comfy_api.input_impl import VideoFromFile
+from typing_extensions import override
+from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
from pathlib import Path
@@ -11,9 +11,9 @@ from pathlib import Path
def normalize_path(path):
return path.replace('\\', '/')
-class Load3D():
+class Load3D(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
+ def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
@@ -26,157 +26,84 @@ class Load3D():
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
]
+ return IO.Schema(
+ node_id="Load3D",
+ display_name="Load 3D & Animation",
+ category="3d",
+ is_experimental=True,
+ inputs=[
+ IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
+ IO.Load3D.Input("image"),
+ IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
+ IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
+ ],
+ outputs=[
+ IO.Image.Output(display_name="image"),
+ IO.Mask.Output(display_name="mask"),
+ IO.String.Output(display_name="mesh_path"),
+ IO.Image.Output(display_name="normal"),
+ IO.Load3DCamera.Output(display_name="camera_info"),
+ IO.Video.Output(display_name="recording_video"),
+ ],
+ )
- return {"required": {
- "model_file": (sorted(files), {"file_upload": True}),
- "image": ("LOAD_3D", {}),
- "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
- "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
- }}
-
- RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
- RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
-
- FUNCTION = "process"
- EXPERIMENTAL = True
-
- CATEGORY = "3d"
-
- def process(self, model_file, image, **kwargs):
+ @classmethod
+ def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
- lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
- lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
- video = VideoFromFile(recording_video_path)
+ video = InputImpl.VideoFromFile(recording_video_path)
- return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
+ return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
-class Load3DAnimation():
+ process = execute # TODO: remove
+
+
+class Preview3D(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Preview3D",
+ display_name="Preview 3D & Animation",
+ category="3d",
+ is_experimental=True,
+ is_output_node=True,
+ inputs=[
+ IO.String.Input("model_file", default="", multiline=False),
+ IO.Load3DCamera.Input("camera_info", optional=True),
+ IO.Image.Input("bg_image", optional=True),
+ ],
+ outputs=[],
+ )
- os.makedirs(input_dir, exist_ok=True)
+ @classmethod
+ def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
+ camera_info = kwargs.get("camera_info", None)
+ bg_image = kwargs.get("bg_image", None)
+ return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
- input_path = Path(input_dir)
- base_path = Path(folder_paths.get_input_directory())
+ process = execute # TODO: remove
- files = [
- normalize_path(str(file_path.relative_to(base_path)))
- for file_path in input_path.rglob("*")
- if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
+
+class Load3DExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ Load3D,
+ Preview3D,
]
- return {"required": {
- "model_file": (sorted(files), {"file_upload": True}),
- "image": ("LOAD_3D_ANIMATION", {}),
- "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
- "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
- }}
- RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
- RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
-
- FUNCTION = "process"
- EXPERIMENTAL = True
-
- CATEGORY = "3d"
-
- def process(self, model_file, image, **kwargs):
- image_path = folder_paths.get_annotated_filepath(image['image'])
- mask_path = folder_paths.get_annotated_filepath(image['mask'])
- normal_path = folder_paths.get_annotated_filepath(image['normal'])
-
- load_image_node = nodes.LoadImage()
- output_image, ignore_mask = load_image_node.load_image(image=image_path)
- ignore_image, output_mask = load_image_node.load_image(image=mask_path)
- normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
-
- video = None
-
- if image['recording'] != "":
- recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
-
- video = VideoFromFile(recording_video_path)
-
- return output_image, output_mask, model_file, normal_image, image['camera_info'], video
-
-class Preview3D():
- @classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "model_file": ("STRING", {"default": "", "multiline": False}),
- },
- "optional": {
- "camera_info": ("LOAD3D_CAMERA", {})
- }}
-
- OUTPUT_NODE = True
- RETURN_TYPES = ()
-
- CATEGORY = "3d"
-
- FUNCTION = "process"
- EXPERIMENTAL = True
-
- def process(self, model_file, **kwargs):
- camera_info = kwargs.get("camera_info", None)
-
- return {
- "ui": {
- "result": [model_file, camera_info]
- }
- }
-
-class Preview3DAnimation():
- @classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "model_file": ("STRING", {"default": "", "multiline": False}),
- },
- "optional": {
- "camera_info": ("LOAD3D_CAMERA", {})
- }}
-
- OUTPUT_NODE = True
- RETURN_TYPES = ()
-
- CATEGORY = "3d"
-
- FUNCTION = "process"
- EXPERIMENTAL = True
-
- def process(self, model_file, **kwargs):
- camera_info = kwargs.get("camera_info", None)
-
- return {
- "ui": {
- "result": [model_file, camera_info]
- }
- }
-
-NODE_CLASS_MAPPINGS = {
- "Load3D": Load3D,
- "Load3DAnimation": Load3DAnimation,
- "Preview3D": Preview3D,
- "Preview3DAnimation": Preview3DAnimation
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
- "Load3D": "Load 3D",
- "Load3DAnimation": "Load 3D - Animation",
- "Preview3D": "Preview 3D",
- "Preview3DAnimation": "Preview 3D - Animation"
-}
+async def comfy_entrypoint() -> Load3DExtension:
+ return Load3DExtension()
diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py
new file mode 100644
index 000000000..95a6ba788
--- /dev/null
+++ b/comfy_extras/nodes_logic.py
@@ -0,0 +1,155 @@
+from typing import TypedDict
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+from comfy_api.latest import _io
+
+
+
+class SwitchNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ template = io.MatchType.Template("switch")
+ return io.Schema(
+ node_id="ComfySwitchNode",
+ display_name="Switch",
+ category="logic",
+ is_experimental=True,
+ inputs=[
+ io.Boolean.Input("switch"),
+ io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
+ io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
+ ],
+ outputs=[
+ io.MatchType.Output(template=template, display_name="output"),
+ ],
+ )
+
+ @classmethod
+ def check_lazy_status(cls, switch, on_false=..., on_true=...):
+ # We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
+ # This trick allows us to ignore the value of the switch and still be able to run execute().
+
+ # One of the inputs may be missing, in which case we need to evaluate the other input
+ if on_false is ...:
+ return ["on_true"]
+ if on_true is ...:
+ return ["on_false"]
+ # Normal lazy switch operation
+ if switch and on_true is None:
+ return ["on_true"]
+ if not switch and on_false is None:
+ return ["on_false"]
+
+ @classmethod
+ def validate_inputs(cls, switch, on_false=..., on_true=...):
+ # This check happens before check_lazy_status(), so we can eliminate the case where
+ # both inputs are missing.
+ if on_false is ... and on_true is ...:
+ return "At least one of on_false or on_true must be connected to Switch node"
+ return True
+
+ @classmethod
+ def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
+ if on_true is ...:
+ return io.NodeOutput(on_false)
+ if on_false is ...:
+ return io.NodeOutput(on_true)
+ return io.NodeOutput(on_true if switch else on_false)
+
+
+class DCTestNode(io.ComfyNode):
+ class DCValues(TypedDict):
+ combo: str
+ string: str
+ integer: int
+ image: io.Image.Type
+        subcombo: dict[str, object]
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="DCTestNode",
+ display_name="DCTest",
+ category="logic",
+ is_output_node=True,
+ inputs=[_io.DynamicCombo.Input("combo", options=[
+ _io.DynamicCombo.Option("option1", [io.String.Input("string")]),
+ _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
+ _io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
+ _io.DynamicCombo.Option("option4", [
+ _io.DynamicCombo.Input("subcombo", options=[
+ _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
+ _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
+ ])
+ ])]
+ )],
+ outputs=[io.AnyType.Output()],
+ )
+
+ @classmethod
+ def execute(cls, combo: DCValues) -> io.NodeOutput:
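+        # A DynamicCombo input resolves to a dict: the key named after the
+        # input ("combo") holds the selected option and the option's own
+        # inputs appear as additional keys (see DCValues above).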
+ combo_val = combo["combo"]
+ if combo_val == "option1":
+ return io.NodeOutput(combo["string"])
+ elif combo_val == "option2":
+ return io.NodeOutput(combo["integer"])
+ elif combo_val == "option3":
+ return io.NodeOutput(combo["image"])
+ elif combo_val == "option4":
+ return io.NodeOutput(f"{combo['subcombo']}")
+ else:
+ raise ValueError(f"Invalid combo: {combo_val}")
+
+
+class AutogrowNamesTestNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"])
+ return io.Schema(
+ node_id="AutogrowNamesTestNode",
+ display_name="AutogrowNamesTest",
+ category="logic",
+ inputs=[
+ _io.Autogrow.Input("autogrow", template=template)
+ ],
+ outputs=[io.String.Output()],
+ )
+
+ @classmethod
+ def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
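+        # The autogrow input resolves to a mapping of grown input names to
+        # values; join them into a single comma-separated string.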
+ vals = list(autogrow.values())
+ combined = ",".join([str(x) for x in vals])
+ return io.NodeOutput(combined)
+
+class AutogrowPrefixTestNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10)
+ return io.Schema(
+ node_id="AutogrowPrefixTestNode",
+ display_name="AutogrowPrefixTest",
+ category="logic",
+ inputs=[
+ _io.Autogrow.Input("autogrow", template=template)
+ ],
+ outputs=[io.String.Output()],
+ )
+
+ @classmethod
+ def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
+ vals = list(autogrow.values())
+ combined = ",".join([str(x) for x in vals])
+ return io.NodeOutput(combined)
+
+class LogicExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ # SwitchNode,
+ # DCTestNode,
+ # AutogrowNamesTestNode,
+ # AutogrowPrefixTestNode,
+ ]
+
+async def comfy_entrypoint() -> LogicExtension:
+ return LogicExtension()
diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py
index a5e405008..290e6f55e 100644
--- a/comfy_extras/nodes_mask.py
+++ b/comfy_extras/nodes_mask.py
@@ -3,11 +3,10 @@ import scipy.ndimage
import torch
import comfy.utils
import node_helpers
-import folder_paths
-import random
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, IO, UI
import nodes
-from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
source = source.to(destination.device)
@@ -46,202 +45,213 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
-class LatentCompositeMasked:
+class LatentCompositeMasked(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "destination": ("LATENT",),
- "source": ("LATENT",),
- "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
- "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
- "resize_source": ("BOOLEAN", {"default": False}),
- },
- "optional": {
- "mask": ("MASK",),
- }
- }
- RETURN_TYPES = ("LATENT",)
- FUNCTION = "composite"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="LatentCompositeMasked",
+ category="latent",
+ inputs=[
+ IO.Latent.Input("destination"),
+ IO.Latent.Input("source"),
+ IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
+ IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
+ IO.Boolean.Input("resize_source", default=False),
+ IO.Mask.Input("mask", optional=True),
+ ],
+ outputs=[IO.Latent.Output()],
+ )
- CATEGORY = "latent"
-
- def composite(self, destination, source, x, y, resize_source, mask = None):
+ @classmethod
+ def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
output = destination.copy()
destination = destination["samples"].clone()
source = source["samples"]
output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
- return (output,)
+ return IO.NodeOutput(output)
-class ImageCompositeMasked:
+ composite = execute # TODO: remove
+
+
+class ImageCompositeMasked(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "destination": ("IMAGE",),
- "source": ("IMAGE",),
- "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "resize_source": ("BOOLEAN", {"default": False}),
- },
- "optional": {
- "mask": ("MASK",),
- }
- }
- RETURN_TYPES = ("IMAGE",)
- FUNCTION = "composite"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ImageCompositeMasked",
+ category="image",
+ inputs=[
+ IO.Image.Input("destination"),
+ IO.Image.Input("source"),
+ IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Boolean.Input("resize_source", default=False),
+ IO.Mask.Input("mask", optional=True),
+ ],
+ outputs=[IO.Image.Output()],
+ )
- CATEGORY = "image"
-
- def composite(self, destination, source, x, y, resize_source, mask = None):
+ @classmethod
+ def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
- return (output,)
+ return IO.NodeOutput(output)
-class MaskToImage:
+ composite = execute # TODO: remove
+
+
+class MaskToImage(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "mask": ("MASK",),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="MaskToImage",
+ display_name="Convert Mask to Image",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ ],
+ outputs=[IO.Image.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("IMAGE",)
- FUNCTION = "mask_to_image"
-
- def mask_to_image(self, mask):
+ @classmethod
+ def execute(cls, mask) -> IO.NodeOutput:
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
- return (result,)
+ return IO.NodeOutput(result)
-class ImageToMask:
+ mask_to_image = execute # TODO: remove
+
+
+class ImageToMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "image": ("IMAGE",),
- "channel": (["red", "green", "blue", "alpha"],),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ImageToMask",
+ display_name="Convert Image to Mask",
+ category="mask",
+ inputs=[
+ IO.Image.Input("image"),
+ IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
- FUNCTION = "image_to_mask"
-
- def image_to_mask(self, image, channel):
+ @classmethod
+ def execute(cls, image, channel) -> IO.NodeOutput:
channels = ["red", "green", "blue", "alpha"]
mask = image[:, :, :, channels.index(channel)]
- return (mask,)
+ return IO.NodeOutput(mask)
-class ImageColorToMask:
+ image_to_mask = execute # TODO: remove
+
+
+class ImageColorToMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "image": ("IMAGE",),
- "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ImageColorToMask",
+ category="mask",
+ inputs=[
+ IO.Image.Input("image"),
+ IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
- FUNCTION = "image_to_mask"
-
- def image_to_mask(self, image, color):
+ @classmethod
+ def execute(cls, image, color) -> IO.NodeOutput:
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
mask = torch.where(temp == color, 1.0, 0).float()
- return (mask,)
+ return IO.NodeOutput(mask)
-class SolidMask:
+ image_to_mask = execute # TODO: remove
+
+
+class SolidMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
- "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
- "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SolidMask",
+ category="mask",
+ inputs=[
+ IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01),
+ IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "solid"
-
- def solid(self, value, width, height):
+ @classmethod
+ def execute(cls, value, width, height) -> IO.NodeOutput:
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
- return (out,)
+ return IO.NodeOutput(out)
-class InvertMask:
+ solid = execute # TODO: remove
+
+
+class InvertMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "mask": ("MASK",),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="InvertMask",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "invert"
-
- def invert(self, mask):
+ @classmethod
+ def execute(cls, mask) -> IO.NodeOutput:
out = 1.0 - mask
- return (out,)
+ return IO.NodeOutput(out)
-class CropMask:
+ invert = execute # TODO: remove
+
+
+class CropMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "mask": ("MASK",),
- "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
- "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="CropMask",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "crop"
-
- def crop(self, mask, x, y, width, height):
+ @classmethod
+ def execute(cls, mask, x, y, width, height) -> IO.NodeOutput:
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
out = mask[:, y:y + height, x:x + width]
- return (out,)
+ return IO.NodeOutput(out)
-class MaskComposite:
+ crop = execute # TODO: remove
+
+
+class MaskComposite(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "destination": ("MASK",),
- "source": ("MASK",),
- "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="MaskComposite",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("destination"),
+ IO.Mask.Input("source"),
+ IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "combine"
-
- def combine(self, destination, source, x, y, operation):
+ @classmethod
+ def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
@@ -267,28 +277,29 @@ class MaskComposite:
output = torch.clamp(output, 0.0, 1.0)
- return (output,)
+ return IO.NodeOutput(output)
-class FeatherMask:
+ combine = execute # TODO: remove
+
+
+class FeatherMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "mask": ("MASK",),
- "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="FeatherMask",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "feather"
-
- def feather(self, mask, left, top, right, bottom):
+ @classmethod
+ def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput:
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
left = min(left, output.shape[-1])
@@ -312,26 +323,28 @@ class FeatherMask:
feather_rate = (y + 1) / bottom
output[:, -y, :] *= feather_rate
- return (output,)
+ return IO.NodeOutput(output)
-class GrowMask:
+ feather = execute # TODO: remove
+
+
+class GrowMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "mask": ("MASK",),
- "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
- "tapered_corners": ("BOOLEAN", {"default": True}),
- },
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="GrowMask",
+ display_name="Grow Mask",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
+ IO.Boolean.Input("tapered_corners", default=True),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
-
- FUNCTION = "expand_mask"
-
- def expand_mask(self, mask, expand, tapered_corners):
+ @classmethod
+ def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput:
c = 0 if tapered_corners else 1
kernel = np.array([[c, 1, c],
[1, 1, 1],
@@ -347,69 +360,74 @@ class GrowMask:
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
output = torch.from_numpy(output)
out.append(output)
- return (torch.stack(out, dim=0),)
+ return IO.NodeOutput(torch.stack(out, dim=0))
-class ThresholdMask:
+ expand_mask = execute # TODO: remove
+
+
+class ThresholdMask(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "mask": ("MASK",),
- "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
- }
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ThresholdMask",
+ category="mask",
+ inputs=[
+ IO.Mask.Input("mask"),
+ IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01),
+ ],
+ outputs=[IO.Mask.Output()],
+ )
- CATEGORY = "mask"
-
- RETURN_TYPES = ("MASK",)
- FUNCTION = "image_to_mask"
-
- def image_to_mask(self, mask, value):
+ @classmethod
+ def execute(cls, mask, value) -> IO.NodeOutput:
mask = (mask > value).float()
- return (mask,)
+ return IO.NodeOutput(mask)
+
+ image_to_mask = execute # TODO: remove
+
# Mask Preview - original implement from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
-class MaskPreview(nodes.SaveImage):
- def __init__(self):
- self.output_dir = folder_paths.get_temp_directory()
- self.type = "temp"
- self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
- self.compress_level = 4
+class MaskPreview(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="MaskPreview",
+ display_name="Preview Mask",
+ category="mask",
+ description="Saves the input images to your ComfyUI output directory.",
+ inputs=[
+ IO.Mask.Input("mask"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {"mask": ("MASK",), },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
-
- FUNCTION = "execute"
- CATEGORY = "mask"
-
- def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
- preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
- return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
+ def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput:
+ return IO.NodeOutput(ui=UI.PreviewMask(mask))
-NODE_CLASS_MAPPINGS = {
- "LatentCompositeMasked": LatentCompositeMasked,
- "ImageCompositeMasked": ImageCompositeMasked,
- "MaskToImage": MaskToImage,
- "ImageToMask": ImageToMask,
- "ImageColorToMask": ImageColorToMask,
- "SolidMask": SolidMask,
- "InvertMask": InvertMask,
- "CropMask": CropMask,
- "MaskComposite": MaskComposite,
- "FeatherMask": FeatherMask,
- "GrowMask": GrowMask,
- "ThresholdMask": ThresholdMask,
- "MaskPreview": MaskPreview
-}
+class MaskExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ LatentCompositeMasked,
+ ImageCompositeMasked,
+ MaskToImage,
+ ImageToMask,
+ ImageColorToMask,
+ SolidMask,
+ InvertMask,
+ CropMask,
+ MaskComposite,
+ FeatherMask,
+ GrowMask,
+ ThresholdMask,
+ MaskPreview,
+ ]
-NODE_DISPLAY_NAME_MAPPINGS = {
- "ImageToMask": "Convert Image to Mask",
- "MaskToImage": "Convert Mask to Image",
-}
+
+async def comfy_entrypoint() -> MaskExtension:
+ return MaskExtension()
diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py
index f7ca9699d..dec2ae841 100644
--- a/comfy_extras/nodes_model_downscale.py
+++ b/comfy_extras/nodes_model_downscale.py
@@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode):
return io.NodeOutput(m)
-NODE_DISPLAY_NAME_MAPPINGS = {
- # Sampling
- "PatchModelAddDownscale": "",
-}
-
class ModelDownscaleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py
index 783c59b6b..c61810dbf 100644
--- a/comfy_extras/nodes_model_patch.py
+++ b/comfy_extras/nodes_model_patch.py
@@ -6,6 +6,7 @@ import comfy.ops
import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
+import comfy.ldm.lumina.controlnet
class BlockWiseControlBlock(torch.nn.Module):
@@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):
return embedding
+def z_image_convert(sd):
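+    # Rename attention keys to the naming used by ZImage_Control and fuse the
+    # separate to_q/to_k/to_v projection weights into a single qkv weight.
+    # Iterating in sorted key order guarantees to_k is seen before to_q and to_v.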
+ replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
+ ".attention.norm_k.weight": ".attention.k_norm.weight",
+ ".attention.norm_q.weight": ".attention.q_norm.weight",
+ ".attention.to_out.0.weight": ".attention.out.weight"
+ }
+
+ out_sd = {}
+ for k in sorted(sd.keys()):
+ w = sd[k]
+
+ k_out = k
+ if k_out.endswith(".attention.to_k.weight"):
+ cc = [w]
+ continue
+ if k_out.endswith(".attention.to_q.weight"):
+ cc = [w] + cc
+ continue
+ if k_out.endswith(".attention.to_v.weight"):
+ cc = cc + [w]
+ w = torch.cat(cc, dim=0)
+ k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
+
+ for r, rr in replace_keys.items():
+ k_out = k_out.replace(r, rr)
+ out_sd[k_out] = w
+
+ return out_sd
+
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
@@ -211,6 +241,9 @@ class ModelPatchLoader:
elif 'feature_embedder.mid_layer_norm.bias' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+        elif 'control_all_x_embedder.2-1.weight' in sd:  # Alibaba PAI Z-Image Fun controlnet
+ sd = z_image_convert(sd)
+ model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@@ -263,6 +296,69 @@ class DiffSynthCnetPatch:
def models(self):
return [self.model_patch]
+class ZImageControlPatch:
+ def __init__(self, model_patch, vae, image, strength):
+ self.model_patch = model_patch
+ self.vae = vae
+ self.image = image
+ self.strength = strength
+ self.encoded_image = self.encode_latent_cond(image)
+ self.encoded_image_size = (image.shape[1], image.shape[2])
+ self.temp_data = None
+
+ def encode_latent_cond(self, image):
+ latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
+ return latent_image
+
+ def __call__(self, kwargs):
+ x = kwargs.get("x")
+ img = kwargs.get("img")
+ txt = kwargs.get("txt")
+ pe = kwargs.get("pe")
+ vec = kwargs.get("vec")
+ block_index = kwargs.get("block_index")
+ spacial_compression = self.vae.spacial_compression_encode()
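+        # Re-encode the control image if the cached encoding no longer matches
+        # the current latent resolution.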
+ if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
+ image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+ loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
+ self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
+ self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
+ comfy.model_management.load_models_gpu(loaded_models)
+
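+        # The control model provides one control block for every 5 transformer
+        # blocks; its residual is added only when block_index is a multiple of 5.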
+ cnet_index = (block_index // 5)
+ cnet_index_float = (block_index / 5)
+
+ kwargs.pop("img") # we do ops in place
+ kwargs.pop("txt")
+
+ cnet_blocks = self.model_patch.model.n_control_layers
+ if cnet_index_float > (cnet_blocks - 1):
+ self.temp_data = None
+ return kwargs
+
+ if self.temp_data is None or self.temp_data[0] > cnet_index:
+ self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+
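+        # Lazily advance the cached control state one control block at a time
+        # until it reaches the block requested by the current call.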
+ while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+ next_layer = self.temp_data[0] + 1
+ self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+
+ if cnet_index_float == self.temp_data[0]:
+ img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+ if cnet_blocks == self.temp_data[0] + 1:
+ self.temp_data = None
+
+ return kwargs
+
+ def to(self, device_or_dtype):
+ if isinstance(device_or_dtype, torch.device):
+ self.encoded_image = self.encoded_image.to(device_or_dtype)
+ self.temp_data = None
+ return self
+
+ def models(self):
+ return [self.model_patch]
+
class QwenImageDiffsynthControlnet:
@classmethod
def INPUT_TYPES(s):
@@ -289,7 +385,10 @@ class QwenImageDiffsynthControlnet:
mask = mask.unsqueeze(2)
mask = 1.0 - mask
- model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
+ if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
+ model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+ else:
+ model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,)
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 9e6ec6780..19b8baaf4 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -1,15 +1,13 @@
-import datetime
-import json
import logging
import os
import numpy as np
import safetensors
import torch
-from PIL import Image, ImageDraw, ImageFont
-from PIL.PngImagePlugin import PngInfo
import torch.utils.checkpoint
-import tqdm
+from tqdm.auto import trange
+from PIL import Image, ImageDraw, ImageFont
+from typing_extensions import override
import comfy.samplers
import comfy.sd
@@ -18,9 +16,9 @@ import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
-from comfy.cli_args import args
-from comfy.comfy_types.node_typing import IO
from comfy.weight_adapter import adapters, adapter_maps
+from comfy_api.latest import ComfyExtension, io, ui
+from comfy.utils import ProgressBar
def make_batch_extra_option_dict(d, indicies, full_size=None):
@@ -56,7 +54,18 @@ def process_cond_list(d, prefix=""):
class TrainSampler(comfy.samplers.Sampler):
- def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+ def __init__(
+ self,
+ loss_fn,
+ optimizer,
+ loss_callback=None,
+ batch_size=1,
+ grad_acc=1,
+ total_steps=1,
+ seed=0,
+ training_dtype=torch.bfloat16,
+ real_dataset=None,
+ ):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
@@ -65,54 +74,138 @@ class TrainSampler(comfy.samplers.Sampler):
self.grad_acc = grad_acc
self.seed = seed
self.training_dtype = training_dtype
+ self.real_dataset: list[torch.Tensor] | None = real_dataset
- def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
+ def fwd_bwd(
+ self,
+ model_wrap,
+ batch_sigmas,
+ batch_noise,
+ batch_latent,
+ cond,
+ indicies,
+ extra_args,
+ dataset_size,
+ bwd=True,
+ ):
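+        # Build the noised input xt and the clean target x0 for the selected
+        # indices, run the model under autocast and compute the loss. When bwd
+        # is True the loss is scaled for gradient accumulation and
+        # backpropagated here; otherwise the caller handles backward.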
+ xt = model_wrap.inner_model.model_sampling.noise_scaling(
+ batch_sigmas, batch_noise, batch_latent, False
+ )
+ x0 = model_wrap.inner_model.model_sampling.noise_scaling(
+ torch.zeros_like(batch_sigmas),
+ torch.zeros_like(batch_noise),
+ batch_latent,
+ False,
+ )
+
+ model_wrap.conds["positive"] = [cond[i] for i in indicies]
+ batch_extra_args = make_batch_extra_option_dict(
+ extra_args, indicies, full_size=dataset_size
+ )
+
+ with torch.autocast(xt.device.type, dtype=self.training_dtype):
+ x0_pred = model_wrap(
+ xt.requires_grad_(True),
+ batch_sigmas.requires_grad_(True),
+ **batch_extra_args,
+ )
+ loss = self.loss_fn(x0_pred, x0)
+ if bwd:
+ bwd_loss = loss / self.grad_acc
+ bwd_loss.backward()
+ return loss
+
+ def sample(
+ self,
+ model_wrap,
+ sigmas,
+ extra_args,
+ callback,
+ noise,
+ latent_image=None,
+ denoise_mask=None,
+ disable_pbar=False,
+ ):
model_wrap.conds = process_cond_list(model_wrap.conds)
cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0)
torch.cuda.empty_cache()
- for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
- noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
- indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
-
- batch_latent = torch.stack([latent_image[i] for i in indicies])
- batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
- batch_sigmas = [
- model_wrap.inner_model.model_sampling.percent_to_sigma(
- torch.rand((1,)).item()
- ) for _ in range(min(self.batch_size, dataset_size))
- ]
- batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
-
- xt = model_wrap.inner_model.model_sampling.noise_scaling(
- batch_sigmas,
- batch_noise,
- batch_latent,
- False
+ ui_pbar = ProgressBar(self.total_steps)
+ for i in (
+ pbar := trange(
+ self.total_steps,
+ desc="Training LoRA",
+ smoothing=0.01,
+ disable=not comfy.utils.PROGRESS_BAR_ENABLED,
)
- x0 = model_wrap.inner_model.model_sampling.noise_scaling(
- torch.zeros_like(batch_sigmas),
- torch.zeros_like(batch_noise),
- batch_latent,
- False
+ ):
+ noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
+ self.seed + i * 1000
)
+ indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
- model_wrap.conds["positive"] = [
- cond[i] for i in indicies
- ]
- batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
+ if self.real_dataset is None:
+ batch_latent = torch.stack([latent_image[i] for i in indicies])
+ batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
+ batch_latent.device
+ )
+ batch_sigmas = [
+ model_wrap.inner_model.model_sampling.percent_to_sigma(
+ torch.rand((1,)).item()
+ )
+ for _ in range(min(self.batch_size, dataset_size))
+ ]
+ batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
- with torch.autocast(xt.device.type, dtype=self.training_dtype):
- x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
- loss = self.loss_fn(x0_pred, x0)
- loss.backward()
- if self.loss_callback:
- self.loss_callback(loss.item())
- pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+ loss = self.fwd_bwd(
+ model_wrap,
+ batch_sigmas,
+ batch_noise,
+ batch_latent,
+ cond,
+ indicies,
+ extra_args,
+ dataset_size,
+ bwd=True,
+ )
+ if self.loss_callback:
+ self.loss_callback(loss.item())
+ pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+ else:
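+                # Variable-resolution dataset: samples cannot be stacked into a
+                # single batch, so each selected latent is run on its own and
+                # the averaged loss is backpropagated once.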
+ total_loss = 0
+ for index in indicies:
+ single_latent = self.real_dataset[index].to(latent_image)
+ batch_noise = noisegen.generate_noise(
+ {"samples": single_latent}
+ ).to(single_latent.device)
+ batch_sigmas = (
+ model_wrap.inner_model.model_sampling.percent_to_sigma(
+ torch.rand((1,)).item()
+ )
+ )
+ batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
+ loss = self.fwd_bwd(
+ model_wrap,
+ batch_sigmas,
+ batch_noise,
+ single_latent,
+ cond,
+ [index],
+ extra_args,
+ dataset_size,
+ bwd=False,
+ )
+ total_loss += loss
+ total_loss = total_loss / self.grad_acc / len(indicies)
+ total_loss.backward()
+ if self.loss_callback:
+ self.loss_callback(total_loss.item())
+ pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
- if (i+1) % self.grad_acc == 0:
+ if (i + 1) % self.grad_acc == 0:
self.optimizer.step()
self.optimizer.zero_grad()
+ ui_pbar.update(1)
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
@@ -134,233 +227,6 @@ class BiasDiff(torch.nn.Module):
return self.passive_memory_usage()
-def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
- """Utility function to load and process a list of images.
-
- Args:
- image_files: List of image filenames
- input_dir: Base directory containing the images
- resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
-
- Returns:
- torch.Tensor: Batch of processed images
- """
- if not image_files:
- raise ValueError("No valid images found in input")
-
- output_images = []
-
- for file in image_files:
- image_path = os.path.join(input_dir, file)
- img = node_helpers.pillow(Image.open, image_path)
-
- if img.mode == "I":
- img = img.point(lambda i: i * (1 / 255))
- img = img.convert("RGB")
-
- if w is None and h is None:
- w, h = img.size[0], img.size[1]
-
- # Resize image to first image
- if img.size[0] != w or img.size[1] != h:
- if resize_method == "Stretch":
- img = img.resize((w, h), Image.Resampling.LANCZOS)
- elif resize_method == "Crop":
- img = img.crop((0, 0, w, h))
- elif resize_method == "Pad":
- img = img.resize((w, h), Image.Resampling.LANCZOS)
- elif resize_method == "None":
- raise ValueError(
- "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
- )
-
- img_array = np.array(img).astype(np.float32) / 255.0
- img_tensor = torch.from_numpy(img_array)[None,]
- output_images.append(img_tensor)
-
- return torch.cat(output_images, dim=0)
-
-
-class LoadImageSetNode:
- @classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "images": (
- [
- f
- for f in os.listdir(folder_paths.get_input_directory())
- if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
- ],
- {"image_upload": True, "allow_batch": True},
- )
- },
- "optional": {
- "resize_method": (
- ["None", "Stretch", "Crop", "Pad"],
- {"default": "None"},
- ),
- },
- }
-
- INPUT_IS_LIST = True
- RETURN_TYPES = ("IMAGE",)
- FUNCTION = "load_images"
- CATEGORY = "loaders"
- EXPERIMENTAL = True
- DESCRIPTION = "Loads a batch of images from a directory for training."
-
- @classmethod
- def VALIDATE_INPUTS(s, images, resize_method):
- filenames = images[0] if isinstance(images[0], list) else images
-
- for image in filenames:
- if not folder_paths.exists_annotated_filepath(image):
- return "Invalid image file: {}".format(image)
- return True
-
- def load_images(self, input_files, resize_method):
- input_dir = folder_paths.get_input_directory()
- valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
- image_files = [
- f
- for f in input_files
- if any(f.lower().endswith(ext) for ext in valid_extensions)
- ]
- output_tensor = load_and_process_images(image_files, input_dir, resize_method)
- return (output_tensor,)
-
-
-class LoadImageSetFromFolderNode:
- @classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
- },
- "optional": {
- "resize_method": (
- ["None", "Stretch", "Crop", "Pad"],
- {"default": "None"},
- ),
- },
- }
-
- RETURN_TYPES = ("IMAGE",)
- FUNCTION = "load_images"
- CATEGORY = "loaders"
- EXPERIMENTAL = True
- DESCRIPTION = "Loads a batch of images from a directory for training."
-
- def load_images(self, folder, resize_method):
- sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
- valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
- image_files = [
- f
- for f in os.listdir(sub_input_dir)
- if any(f.lower().endswith(ext) for ext in valid_extensions)
- ]
- output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
- return (output_tensor,)
-
-
-class LoadImageTextSetFromFolderNode:
- @classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}),
- "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}),
- },
- "optional": {
- "resize_method": (
- ["None", "Stretch", "Crop", "Pad"],
- {"default": "None"},
- ),
- "width": (
- IO.INT,
- {
- "default": -1,
- "min": -1,
- "max": 10000,
- "step": 1,
- "tooltip": "The width to resize the images to. -1 means use the original width.",
- },
- ),
- "height": (
- IO.INT,
- {
- "default": -1,
- "min": -1,
- "max": 10000,
- "step": 1,
- "tooltip": "The height to resize the images to. -1 means use the original height.",
- },
- )
- },
- }
-
- RETURN_TYPES = ("IMAGE", IO.CONDITIONING,)
- FUNCTION = "load_images"
- CATEGORY = "loaders"
- EXPERIMENTAL = True
- DESCRIPTION = "Loads a batch of images and caption from a directory for training."
-
- def load_images(self, folder, clip, resize_method, width=None, height=None):
- if clip is None:
- raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
-
- logging.info(f"Loading images from folder: {folder}")
-
- sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
- valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
-
- image_files = []
- for item in os.listdir(sub_input_dir):
- path = os.path.join(sub_input_dir, item)
- if any(item.lower().endswith(ext) for ext in valid_extensions):
- image_files.append(path)
- elif os.path.isdir(path):
- # Support kohya-ss/sd-scripts folder structure
- repeat = 1
- if item.split("_")[0].isdigit():
- repeat = int(item.split("_")[0])
- image_files.extend([
- os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
- ] * repeat)
-
- caption_file_path = [
- f.replace(os.path.splitext(f)[1], ".txt")
- for f in image_files
- ]
- captions = []
- for caption_file in caption_file_path:
- caption_path = os.path.join(sub_input_dir, caption_file)
- if os.path.exists(caption_path):
- with open(caption_path, "r", encoding="utf-8") as f:
- caption = f.read().strip()
- captions.append(caption)
- else:
- captions.append("")
-
- width = width if width != -1 else None
- height = height if height != -1 else None
- output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
-
- logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
-
- logging.info(f"Encoding captions from {sub_input_dir}.")
- conditions = []
- empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
- for text in captions:
- if text == "":
- conditions.append(empty_cond)
- tokens = clip.tokenize(text)
- conditions.extend(clip.encode_from_tokens_scheduled(tokens))
- logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
- return (output_tensor, conditions)
-
-
def draw_loss_graph(loss_map, steps):
width, height = 500, 300
img = Image.new("RGB", (width, height), "white")
@@ -379,10 +245,14 @@ def draw_loss_graph(loss_map, steps):
return img
-def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
+def find_all_highest_child_module_with_forward(
+ model: torch.nn.Module, result=None, name=None
+):
if result is None:
result = []
- elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
+ elif hasattr(model, "forward") and not isinstance(
+ model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
+ ):
result.append(model)
logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
return result
@@ -396,12 +266,13 @@ def patch(m):
if not hasattr(m, "forward"):
return
org_forward = m.forward
+
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
+
def checkpointing_fwd(*args, **kwargs):
- return torch.utils.checkpoint.checkpoint(
- fwd, args, kwargs, use_reentrant=False
- )
+ return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
+
m.org_forward = org_forward
m.forward = checkpointing_fwd
@@ -412,130 +283,126 @@ def unpatch(m):
del m.org_forward
-class TrainLoraNode:
+class TrainLoraNode(io.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
- "latents": (
- "LATENT",
- {
- "tooltip": "The Latents to use for training, serve as dataset/input of the model."
- },
+ def define_schema(cls):
+ return io.Schema(
+ node_id="TrainLoraNode",
+ display_name="Train LoRA",
+ category="training",
+ is_experimental=True,
+ is_input_list=True, # All inputs become lists
+ inputs=[
+ io.Model.Input("model", tooltip="The model to train the LoRA on."),
+ io.Latent.Input(
+ "latents",
+ tooltip="The Latents to use for training, serve as dataset/input of the model.",
),
- "positive": (
- IO.CONDITIONING,
- {"tooltip": "The positive conditioning to use for training."},
+ io.Conditioning.Input(
+ "positive", tooltip="The positive conditioning to use for training."
),
- "batch_size": (
- IO.INT,
- {
- "default": 1,
- "min": 1,
- "max": 10000,
- "step": 1,
- "tooltip": "The batch size to use for training.",
- },
+ io.Int.Input(
+ "batch_size",
+ default=1,
+ min=1,
+ max=10000,
+ tooltip="The batch size to use for training.",
),
- "grad_accumulation_steps": (
- IO.INT,
- {
- "default": 1,
- "min": 1,
- "max": 1024,
- "step": 1,
- "tooltip": "The number of gradient accumulation steps to use for training.",
- }
+ io.Int.Input(
+ "grad_accumulation_steps",
+ default=1,
+ min=1,
+ max=1024,
+ tooltip="The number of gradient accumulation steps to use for training.",
),
- "steps": (
- IO.INT,
- {
- "default": 16,
- "min": 1,
- "max": 100000,
- "tooltip": "The number of steps to train the LoRA for.",
- },
+ io.Int.Input(
+ "steps",
+ default=16,
+ min=1,
+ max=100000,
+ tooltip="The number of steps to train the LoRA for.",
),
- "learning_rate": (
- IO.FLOAT,
- {
- "default": 0.0005,
- "min": 0.0000001,
- "max": 1.0,
- "step": 0.000001,
- "tooltip": "The learning rate to use for training.",
- },
+ io.Float.Input(
+ "learning_rate",
+ default=0.0005,
+ min=0.0000001,
+ max=1.0,
+ step=0.0000001,
+ tooltip="The learning rate to use for training.",
),
- "rank": (
- IO.INT,
- {
- "default": 8,
- "min": 1,
- "max": 128,
- "tooltip": "The rank of the LoRA layers.",
- },
+ io.Int.Input(
+ "rank",
+ default=8,
+ min=1,
+ max=128,
+ tooltip="The rank of the LoRA layers.",
),
- "optimizer": (
- ["AdamW", "Adam", "SGD", "RMSprop"],
- {
- "default": "AdamW",
- "tooltip": "The optimizer to use for training.",
- },
+ io.Combo.Input(
+ "optimizer",
+ options=["AdamW", "Adam", "SGD", "RMSprop"],
+ default="AdamW",
+ tooltip="The optimizer to use for training.",
),
- "loss_function": (
- ["MSE", "L1", "Huber", "SmoothL1"],
- {
- "default": "MSE",
- "tooltip": "The loss function to use for training.",
- },
+ io.Combo.Input(
+ "loss_function",
+ options=["MSE", "L1", "Huber", "SmoothL1"],
+ default="MSE",
+ tooltip="The loss function to use for training.",
),
- "seed": (
- IO.INT,
- {
- "default": 0,
- "min": 0,
- "max": 0xFFFFFFFFFFFFFFFF,
- "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
- },
+ io.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=0xFFFFFFFFFFFFFFFF,
+ tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
),
- "training_dtype": (
- ["bf16", "fp32"],
- {"default": "bf16", "tooltip": "The dtype to use for training."},
+ io.Combo.Input(
+ "training_dtype",
+ options=["bf16", "fp32"],
+ default="bf16",
+ tooltip="The dtype to use for training.",
),
- "lora_dtype": (
- ["bf16", "fp32"],
- {"default": "bf16", "tooltip": "The dtype to use for lora."},
+ io.Combo.Input(
+ "lora_dtype",
+ options=["bf16", "fp32"],
+ default="bf16",
+ tooltip="The dtype to use for lora.",
),
- "algorithm": (
- list(adapter_maps.keys()),
- {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
+ io.Combo.Input(
+ "algorithm",
+ options=list(adapter_maps.keys()),
+ default=list(adapter_maps.keys())[0],
+ tooltip="The algorithm to use for training.",
),
- "gradient_checkpointing": (
- IO.BOOLEAN,
- {
- "default": True,
- "tooltip": "Use gradient checkpointing for training.",
- }
+ io.Boolean.Input(
+ "gradient_checkpointing",
+ default=True,
+ tooltip="Use gradient checkpointing for training.",
),
- "existing_lora": (
- folder_paths.get_filename_list("loras") + ["[None]"],
- {
- "default": "[None]",
- "tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
- },
+ io.Combo.Input(
+ "existing_lora",
+ options=folder_paths.get_filename_list("loras") + ["[None]"],
+ default="[None]",
+ tooltip="The existing LoRA to append to. Set to None for new LoRA.",
),
- },
- }
+ ],
+ outputs=[
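+                # io.Custom(...) outputs reuse the "LORA_MODEL" and "LOSS_MAP" socket types from the V1 node.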
+ io.Model.Output(
+ display_name="model", tooltip="Model with LoRA applied"
+ ),
+ io.Custom("LORA_MODEL").Output(
+ display_name="lora", tooltip="LoRA weights"
+ ),
+ io.Custom("LOSS_MAP").Output(
+ display_name="loss_map", tooltip="Loss history"
+ ),
+ io.Int.Output(display_name="steps", tooltip="Total training steps"),
+ ],
+ )
- RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
- RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
- FUNCTION = "train"
- CATEGORY = "training"
- EXPERIMENTAL = True
-
- def train(
- self,
+ @classmethod
+ def execute(
+ cls,
model,
latents,
positive,
@@ -553,13 +420,74 @@ class TrainLoraNode:
gradient_checkpointing,
existing_lora,
):
+ # Extract scalars from lists (due to is_input_list=True)
+ model = model[0]
+ batch_size = batch_size[0]
+ steps = steps[0]
+ grad_accumulation_steps = grad_accumulation_steps[0]
+ learning_rate = learning_rate[0]
+ rank = rank[0]
+ optimizer = optimizer[0]
+ loss_function = loss_function[0]
+ seed = seed[0]
+ training_dtype = training_dtype[0]
+ lora_dtype = lora_dtype[0]
+ algorithm = algorithm[0]
+ gradient_checkpointing = gradient_checkpointing[0]
+ existing_lora = existing_lora[0]
+
+ # Handle latents - either single dict or list of dicts
+ if len(latents) == 1:
+ latents = latents[0]["samples"] # Single latent dict
+ else:
+ latent_list = []
+ for latent in latents:
+ latent = latent["samples"]
+ bs = latent.shape[0]
+ if bs != 1:
+ for sub_latent in latent:
+ latent_list.append(sub_latent[None])
+ else:
+ latent_list.append(latent)
+ latents = latent_list
+
+ # Handle conditioning - either single list or list of lists
+ if len(positive) == 1:
+ positive = positive[0] # Single conditioning list
+ else:
+ # Multiple conditioning lists - flatten
+ flat_positive = []
+ for cond in positive:
+ if isinstance(cond, list):
+ flat_positive.extend(cond)
+ else:
+ flat_positive.append(cond)
+ positive = flat_positive
+
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
- latents = latents["samples"].to(dtype)
- num_images = latents.shape[0]
+ # latents here can be list of different size latent or one large batch
+ if isinstance(latents, list):
+ all_shapes = set()
+ latents = [t.to(dtype) for t in latents]
+ for latent in latents:
+ all_shapes.add(latent.shape)
+ logging.info(f"Latent shapes: {all_shapes}")
+            if len(all_shapes) > 1:
+                multi_res = True
+            else:
+                multi_res = False
+                # All latents share one shape, so merge them into a single batch tensor.
+                latents = torch.cat(latents, dim=0)
+            num_images = len(latents)
+        elif isinstance(latents, torch.Tensor):
+            multi_res = False
+            latents = latents.to(dtype)
+            num_images = latents.shape[0]
+        else:
+            logging.error(f"Invalid latents type: {type(latents)}")
+
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images
@@ -591,9 +519,7 @@ class TrainLoraNode:
shape = m.weight.shape
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
- dora_scale = existing_weights.get(
- f"{key}.dora_scale", None
- )
+ dora_scale = existing_weights.get(f"{key}.dora_scale", None)
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
n, existing_weights, alpha, dora_scale
@@ -605,7 +531,9 @@ class TrainLoraNode:
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
- train_adapter = existing_adapter.to_train().to(lora_dtype)
+ train_adapter = existing_adapter.to_train().to(
+ lora_dtype
+ )
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
@@ -629,7 +557,9 @@ class TrainLoraNode:
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
- torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
+ torch.zeros(
+ m.bias.shape, dtype=lora_dtype, requires_grad=True
+ )
)
bias_module = BiasDiff(bias)
lora_sd["{}.diff_b".format(n)] = bias
@@ -657,24 +587,31 @@ class TrainLoraNode:
# setup models
if gradient_checkpointing:
- for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
+ for m in find_all_highest_child_module_with_forward(
+ mp.model.diffusion_model
+ ):
patch(m)
mp.model.requires_grad_(False)
- comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
+ comfy.model_management.load_models_gpu(
+ [mp], memory_required=1e20, force_full_load=True
+ )
# Setup sampler and guider like in test script
loss_map = {"loss": []}
+
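+        # Record each training step's loss so it can be returned in the LOSS_MAP output.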
def loss_callback(loss):
loss_map["loss"].append(loss)
+
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
- total_steps=steps*grad_accumulation_steps,
+ total_steps=steps * grad_accumulation_steps,
seed=seed,
- training_dtype=dtype
+ training_dtype=dtype,
+ real_dataset=latents if multi_res else None,
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
@@ -684,12 +621,15 @@ class TrainLoraNode:
# Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
+ if multi_res:
+ # use first latent as dummy latent if multi_res
+ latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
- seed=noise.seed
+ seed=noise.seed,
)
finally:
for m in mp.model.modules():
@@ -702,111 +642,118 @@ class TrainLoraNode:
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
- return (mp, lora_sd, loss_map, steps + existing_steps)
+ return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
-class LoraModelLoader:
- def __init__(self):
- self.loaded_lora = None
+class LoraModelLoader(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="LoraModelLoader",
+ display_name="Load LoRA Model",
+ category="loaders",
+ is_experimental=True,
+ inputs=[
+ io.Model.Input(
+ "model", tooltip="The diffusion model the LoRA will be applied to."
+ ),
+ io.Custom("LORA_MODEL").Input(
+ "lora", tooltip="The LoRA model to apply to the diffusion model."
+ ),
+ io.Float.Input(
+ "strength_model",
+ default=1.0,
+ min=-100.0,
+ max=100.0,
+ tooltip="How strongly to modify the diffusion model. This value can be negative.",
+ ),
+ ],
+ outputs=[
+ io.Model.Output(
+ display_name="model", tooltip="The modified diffusion model."
+ ),
+ ],
+ )
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
- "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
- "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
- }
- }
-
- RETURN_TYPES = ("MODEL",)
- OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
- FUNCTION = "load_lora_model"
-
- CATEGORY = "loaders"
- DESCRIPTION = "Load Trained LoRA weights from Train LoRA node."
- EXPERIMENTAL = True
-
- def load_lora_model(self, model, lora, strength_model):
+ def execute(cls, model, lora, strength_model):
if strength_model == 0:
- return (model, )
+ return io.NodeOutput(model)
- model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
- return (model_lora, )
+ model_lora, _ = comfy.sd.load_lora_for_models(
+ model, None, lora, strength_model, 0
+ )
+ return io.NodeOutput(model_lora)
-class SaveLoRA:
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
+class SaveLoRA(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SaveLoRA",
+ display_name="Save LoRA Weights",
+ category="loaders",
+ is_experimental=True,
+ is_output_node=True,
+ inputs=[
+ io.Custom("LORA_MODEL").Input(
+ "lora",
+ tooltip="The LoRA model to save. Do not use the model with LoRA layers.",
+ ),
+ io.String.Input(
+ "prefix",
+ default="loras/ComfyUI_trained_lora",
+ tooltip="The prefix to use for the saved LoRA file.",
+ ),
+ io.Int.Input(
+ "steps",
+ optional=True,
+                    tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
+ ),
+ ],
+ outputs=[],
+ )
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "lora": (
- IO.LORA_MODEL,
- {
- "tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
- },
- ),
- "prefix": (
- "STRING",
- {
- "default": "loras/ComfyUI_trained_lora",
- "tooltip": "The prefix to use for the saved LoRA file.",
- },
- ),
- },
- "optional": {
- "steps": (
- IO.INT,
- {
- "forceInput": True,
- "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
- },
- ),
- },
- }
-
- RETURN_TYPES = ()
- FUNCTION = "save"
- CATEGORY = "loaders"
- EXPERIMENTAL = True
- OUTPUT_NODE = True
-
- def save(self, lora, prefix, steps=None):
- full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir)
+ def execute(cls, lora, prefix, steps=None):
+ output_dir = folder_paths.get_output_directory()
+ full_output_folder, filename, counter, subfolder, filename_prefix = (
+ folder_paths.get_save_image_path(prefix, output_dir)
+ )
if steps is None:
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
else:
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
safetensors.torch.save_file(lora, output_checkpoint)
- return {}
+ return io.NodeOutput()
-class LossGraphNode:
- def __init__(self):
- self.output_dir = folder_paths.get_temp_directory()
+class LossGraphNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="LossGraphNode",
+ display_name="Plot Loss Graph",
+ category="training",
+ is_experimental=True,
+ is_output_node=True,
+ inputs=[
+ io.Custom("LOSS_MAP").Input(
+ "loss", tooltip="Loss map from training node."
+ ),
+ io.String.Input(
+ "filename_prefix",
+ default="loss_graph",
+ tooltip="Prefix for the saved loss graph image.",
+ ),
+ ],
+ outputs=[],
+ hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
+ )
@classmethod
- def INPUT_TYPES(s):
- return {
- "required": {
- "loss": (IO.LOSS_MAP, {"default": {}}),
- "filename_prefix": (IO.STRING, {"default": "loss_graph"}),
- },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
-
- RETURN_TYPES = ()
- FUNCTION = "plot_loss"
- OUTPUT_NODE = True
- CATEGORY = "training"
- EXPERIMENTAL = True
- DESCRIPTION = "Plots the loss graph and saves it to the output directory."
-
- def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
+ def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None):
loss_values = loss["loss"]
width, height = 800, 480
margin = 40
@@ -849,47 +796,27 @@ class LossGraphNode:
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
)
- metadata = None
- if not args.disable_metadata:
- metadata = PngInfo()
- if prompt is not None:
- metadata.add_text("prompt", json.dumps(prompt))
- if extra_pnginfo is not None:
- for x in extra_pnginfo:
- metadata.add_text(x, json.dumps(extra_pnginfo[x]))
+ # Convert PIL image to tensor for PreviewImage
+ img_array = np.array(img).astype(np.float32) / 255.0
+ img_tensor = torch.from_numpy(img_array)[None,] # [1, H, W, 3]
- date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
- img.save(
- os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
- pnginfo=metadata,
- )
- return {
- "ui": {
- "images": [
- {
- "filename": f"{filename_prefix}_{date}.png",
- "subfolder": "",
- "type": "temp",
- }
- ]
- }
- }
+ # Return preview UI
+ return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls))
-NODE_CLASS_MAPPINGS = {
- "TrainLoraNode": TrainLoraNode,
- "SaveLoRANode": SaveLoRA,
- "LoraModelLoader": LoraModelLoader,
- "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
- "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode,
- "LossGraphNode": LossGraphNode,
-}
+# ========== Extension Setup ==========
-NODE_DISPLAY_NAME_MAPPINGS = {
- "TrainLoraNode": "Train LoRA",
- "SaveLoRANode": "Save LoRA Weights",
- "LoraModelLoader": "Load LoRA Model",
- "LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
- "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder",
- "LossGraphNode": "Plot Loss Graph",
-}
+
+class TrainingExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ TrainLoraNode,
+ LoraModelLoader,
+ SaveLoRA,
+ LossGraphNode,
+ ]
+
+
+async def comfy_entrypoint() -> TrainingExtension:
+ return TrainingExtension()
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 69fabb12e..c609e03da 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -8,10 +8,7 @@ import json
from typing import Optional
from typing_extensions import override
from fractions import Fraction
-from comfy_api.input import AudioInput, ImageInput, VideoInput
-from comfy_api.input_impl import VideoFromComponents, VideoFromFile
-from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
-from comfy_api.latest import ComfyExtension, io, ui
+from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
from comfy.cli_args import args
class SaveWEBM(io.ComfyNode):
@@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode):
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
],
- outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode):
inputs=[
io.Video.Input("video", tooltip="The video to save."),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
- io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
- io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
+ io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
+ io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
],
- outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
- def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
+ def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
@@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode):
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
- file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
+ file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
- format=format,
+ format=Types.VideoContainer(format),
codec=codec,
metadata=saved_metadata
)
@@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode):
)
@classmethod
- def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
+ def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
return io.NodeOutput(
- VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
+ InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
)
class GetVideoComponents(io.ComfyNode):
@@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode):
)
@classmethod
- def execute(cls, video: VideoInput) -> io.NodeOutput:
+ def execute(cls, video: Input.Video) -> io.NodeOutput:
components = video.get_components()
-
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
+
class LoadVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode):
@classmethod
def execute(cls, file) -> io.NodeOutput:
video_path = folder_paths.get_annotated_filepath(file)
- return io.NodeOutput(VideoFromFile(video_path))
+ return io.NodeOutput(InputImpl.VideoFromFile(video_path))
@classmethod
def fingerprint_inputs(s, file):
diff --git a/comfyui_version.py b/comfyui_version.py
index fa4b4f4b0..4b039356e 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.3.75"
+__version__ = "0.3.76"
diff --git a/cuda_malloc.py b/cuda_malloc.py
index 6520d5123..ee2bc4b69 100644
--- a/cuda_malloc.py
+++ b/cuda_malloc.py
@@ -63,18 +63,22 @@ def cuda_malloc_supported():
return True
+version = ""
+
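+# Read torch/version.py directly via importlib so the torch version is known without importing torch itself.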
+try:
+ torch_spec = importlib.util.find_spec("torch")
+ for folder in torch_spec.submodule_search_locations:
+ ver_file = os.path.join(folder, "version.py")
+ if os.path.isfile(ver_file):
+ spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ version = module.__version__
+except:
+ pass
+
if not args.cuda_malloc:
try:
- version = ""
- torch_spec = importlib.util.find_spec("torch")
- for folder in torch_spec.submodule_search_locations:
- ver_file = os.path.join(folder, "version.py")
- if os.path.isfile(ver_file):
- spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
- version = module.__version__
-
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
args.cuda_malloc = cuda_malloc_supported()
@@ -90,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc:
env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
+
+def get_torch_version_noimport():
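+    """Return the torch version string detected above, without importing the torch package."""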
+ return str(version)
diff --git a/execution.py b/execution.py
index 17c77beab..c2186ac98 100644
--- a/execution.py
+++ b/execution.py
@@ -34,7 +34,7 @@ from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
-from comfy_api.latest import io
+from comfy_api.latest import io, _io
class ExecutionResult(Enum):
@@ -76,7 +76,7 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
- input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
+ input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
is_changed = await resolve_map_node_over_list_results(is_changed)
@@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal)
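+    # v3_data carries V3 schema extras (e.g. hidden inputs, dynamic input layout) alongside the resolved inputs.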
+ v3_data: io.V3Data = {}
if is_v3:
- valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
+ valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
else:
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
- return input_data_all, missing_keys, hidden_inputs_v3
+ v3_data["hidden_inputs"] = hidden_inputs_v3
+ return input_data_all, missing_keys, v3_data
map_node_over_list = None #Don't hook this please
@@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results):
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
-async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
+async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -259,13 +261,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
- class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
+ class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
- class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
+ class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
+ # in case of dynamic inputs, restructure inputs to expected nested dict
+ if v3_data is not None:
+ inputs = _io.build_nested_inputs(inputs, v3_data)
# V1
else:
f = getattr(obj, func)
@@ -320,8 +325,8 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
-async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
- return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
+ return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@@ -460,7 +465,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
- input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
+ input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -475,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present:
- required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
+ required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@@ -507,7 +512,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
- output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+ output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
@@ -745,18 +750,17 @@ async def validate_inputs(prompt_id, prompt, item, validated):
class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
- class_inputs = obj_class.INPUT_TYPES()
- valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
-
errors = []
valid = True
validate_function_inputs = []
validate_has_kwargs = False
if issubclass(obj_class, _ComfyNodeInternal):
+ class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
validate_function_name = "validate_inputs"
validate_function = first_real_override(obj_class, validate_function_name)
else:
+ class_inputs = obj_class.INPUT_TYPES()
validate_function_name = "VALIDATE_INPUTS"
validate_function = getattr(obj_class, validate_function_name, None)
if validate_function is not None:
@@ -765,6 +769,8 @@ async def validate_inputs(prompt_id, prompt, item, validated):
validate_has_kwargs = argspec.varkw is not None
received_types = {}
+ valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
+
for x in valid_inputs:
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None
@@ -935,7 +941,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
- input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
+ input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:
@@ -943,7 +949,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
- ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
+ ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered:
for i, r in enumerate(ret):
diff --git a/folder_paths.py b/folder_paths.py
index ffdc4d020..9c96540e3 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -137,6 +137,71 @@ def set_user_directory(user_dir: str) -> None:
user_directory = user_dir
+# System User Protection - Protects system directories from HTTP endpoint access
+# System Users are internal-only users that cannot be accessed via HTTP endpoints.
+# They use the '__' prefix convention (similar to Python's private member convention).
+SYSTEM_USER_PREFIX = "__"
+
+
+def get_system_user_directory(name: str = "system") -> str:
+ """
+ Get the path to a System User directory.
+
+ System User directories (prefixed with '__') are only accessible via internal API,
+ not through HTTP endpoints. Use this for storing system-internal data that
+ should not be exposed to users.
+
+ Args:
+ name: System user name (e.g., "system", "cache"). Must be alphanumeric
+ with underscores allowed, but cannot start with underscore.
+
+ Returns:
+ Absolute path to the system user directory.
+
+ Raises:
+ ValueError: If name is empty, invalid, or starts with underscore.
+
+ Example:
+ >>> get_system_user_directory("cache")
+ '/path/to/user/__cache'
+ """
+ if not name or not isinstance(name, str):
+ raise ValueError("System user name cannot be empty")
+ if not name.replace("_", "").isalnum():
+ raise ValueError(f"Invalid system user name: '{name}'")
+ if name.startswith("_"):
+ raise ValueError("System user name should not start with underscore")
+ return os.path.join(get_user_directory(), f"{SYSTEM_USER_PREFIX}{name}")
+
+
+def get_public_user_directory(user_id: str) -> str | None:
+ """
+ Get the path to a Public User directory for HTTP endpoint access.
+
+ This function provides structural security by returning None for any
+ System User (prefixed with '__'). All HTTP endpoints should use this
+ function instead of directly constructing user paths.
+
+ Args:
+ user_id: User identifier from HTTP request.
+
+ Returns:
+ Absolute path to the user directory, or None if user_id is invalid
+ or refers to a System User.
+
+ Example:
+ >>> get_public_user_directory("default")
+ '/path/to/user/default'
+ >>> get_public_user_directory("__system")
+ None
+ """
+ if not user_id or not isinstance(user_id, str):
+ return None
+ if user_id.startswith(SYSTEM_USER_PREFIX):
+ return None
+ return os.path.join(get_user_directory(), user_id)
+
+
#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name: str) -> str | None:
if type_name == "output":
diff --git a/latent_preview.py b/latent_preview.py
index ddf6dcf49..66bded4b9 100644
--- a/latent_preview.py
+++ b/latent_preview.py
@@ -2,17 +2,24 @@ import torch
from PIL import Image
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
+from comfy.sd import VAE
import comfy.model_management
import folder_paths
import comfy.utils
import logging
MAX_PREVIEW_RESOLUTION = args.preview_size
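+# TAE decoders for video latent formats; these are loaded as a full comfy.sd.VAE and previewed with TAEHVPreviewerImpl below.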
+VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
-def preview_to_image(latent_image):
- latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
- .mul(0xFF) # to 0..255
- )
+def preview_to_image(latent_image, do_scale=True):
+ if do_scale:
+ latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
+ .mul(0xFF) # to 0..255
+ )
+ else:
+ latents_ubyte = (latent_image.clamp(0, 1)
+ .mul(0xFF) # to 0..255
+ )
if comfy.model_management.directml_enabled:
latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
@@ -35,6 +42,10 @@ class TAESDPreviewerImpl(LatentPreviewer):
x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
return preview_to_image(x_sample)
+class TAEHVPreviewerImpl(TAESDPreviewerImpl):
+ def decode_latent_to_preview(self, x0):
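+        # Decode just the first frame of the first latent; the VAE output is already in 0..1, so skip the -1..1 rescale.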
+ x_sample = self.taesd.decode(x0[:1, :, :1])[0][0]
+ return preview_to_image(x_sample, do_scale=False)
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None):
@@ -81,8 +92,13 @@ def get_previewer(device, latent_format):
if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:
- taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
- previewer = TAESDPreviewerImpl(taesd)
+ if latent_format.taesd_decoder_name in VIDEO_TAES:
+ taesd = VAE(comfy.utils.load_torch_file(taesd_decoder_path))
+ taesd.first_stage_model.show_progress_bar = False
+ previewer = TAEHVPreviewerImpl(taesd)
+ else:
+ taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
+ previewer = TAESDPreviewerImpl(taesd)
else:
logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
diff --git a/main.py b/main.py
index 0cd815d9e..0d02a087b 100644
--- a/main.py
+++ b/main.py
@@ -167,6 +167,9 @@ if __name__ == "__main__":
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
import cuda_malloc
+ if "rocm" in cuda_malloc.get_torch_version_noimport():
+ os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
+
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
diff --git a/manager_requirements.txt b/manager_requirements.txt
index 52cc5389c..b95cefb74 100644
--- a/manager_requirements.txt
+++ b/manager_requirements.txt
@@ -1 +1 @@
-comfyui_manager==4.0.3b3
+comfyui_manager==4.0.3b4
diff --git a/nodes.py b/nodes.py
index 902468895..8d28a725d 100644
--- a/nodes.py
+++ b/nodes.py
@@ -695,8 +695,10 @@ class LoraLoaderModelOnly(LoraLoader):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader:
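+    # Image TAEs need matching encoder/decoder pairs in vae_approx; video TAEs are matched by filename prefix and loaded from vae_approx directly.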
+ video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
+ image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
@staticmethod
- def vae_list():
+ def vae_list(s):
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
@@ -725,6 +727,11 @@ class VAELoader:
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
+ else:
+ for tae in s.video_taes:
+ if v.startswith(tae):
+ vaes.append(v)
+
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
@@ -768,7 +775,7 @@ class VAELoader:
@classmethod
def INPUT_TYPES(s):
- return {"required": { "vae_name": (s.vae_list(), )}}
+ return {"required": { "vae_name": (s.vae_list(s), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@@ -779,10 +786,13 @@ class VAELoader:
if vae_name == "pixel_space":
sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0)
- elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
+ elif vae_name in self.image_taes:
sd = self.load_taesd(vae_name)
else:
- vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
+ if os.path.splitext(vae_name)[0] in self.video_taes:
+ vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name)
+ else:
+ vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
vae.throw_exception_if_invalid()
@@ -932,7 +942,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
- "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ),
+ "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -960,7 +970,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
- "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ),
+ "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -2287,6 +2297,7 @@ async def init_builtin_extra_nodes():
"nodes_images.py",
"nodes_video_model.py",
"nodes_train.py",
+ "nodes_dataset.py",
"nodes_sag.py",
"nodes_perpneg.py",
"nodes_stable3d.py",
@@ -2344,7 +2355,9 @@ async def init_builtin_extra_nodes():
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_rope.py",
+ "nodes_logic.py",
"nodes_nop.py",
+ "nodes_kandinsky5.py",
]
import_failed = []
diff --git a/pyproject.toml b/pyproject.toml
index 9009e65fe..02b94a0ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.3.75"
+version = "0.3.76"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
diff --git a/requirements.txt b/requirements.txt
index 5f20816d6..11a7ac245 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.30.6
-comfyui-workflow-templates==0.7.20
+comfyui-frontend-package==1.33.13
+comfyui-workflow-templates==0.7.54
comfyui-embedded-docs==0.3.1
torch
torchsde
diff --git a/server.py b/server.py
index 6b8d94f3f..ac4f42222 100644
--- a/server.py
+++ b/server.py
@@ -98,7 +98,7 @@ def create_cors_middleware(allowed_origin: str):
response = await handler(request)
response.headers['Access-Control-Allow-Origin'] = allowed_origin
- response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
+ response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response.headers['Access-Control-Allow-Credentials'] = 'true'
return response
@@ -177,7 +177,7 @@ def create_block_external_middleware():
else:
response = await handler(request)
- response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
+ response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
return response
return block_external_middleware
diff --git a/tests-unit/app_test/user_manager_system_user_test.py b/tests-unit/app_test/user_manager_system_user_test.py
new file mode 100644
index 000000000..63b1ac5e5
--- /dev/null
+++ b/tests-unit/app_test/user_manager_system_user_test.py
@@ -0,0 +1,193 @@
+"""Tests for System User Protection in user_manager.py
+
+Tests cover:
+- get_request_user_id(): 1st defense layer - blocks System Users from HTTP headers
+- get_request_user_filepath(): 2nd defense layer - structural blocking via get_public_user_directory()
+- add_user(): 3rd defense layer - prevents creation of System User names
+- Defense layers integration tests
+"""
+
+import pytest
+from unittest.mock import MagicMock, patch
+import tempfile
+
+import folder_paths
+from app.user_manager import UserManager
+
+
+@pytest.fixture
+def mock_user_directory():
+ """Create a temporary user directory."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ original_dir = folder_paths.get_user_directory()
+ folder_paths.set_user_directory(temp_dir)
+ yield temp_dir
+ folder_paths.set_user_directory(original_dir)
+
+
+@pytest.fixture
+def user_manager(mock_user_directory):
+ """Create a UserManager instance for testing."""
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ manager = UserManager()
+ # Add a default user for testing
+ manager.users = {"default": "default", "test_user_123": "Test User"}
+ yield manager
+
+
+@pytest.fixture
+def mock_request():
+ """Create a mock request object."""
+ request = MagicMock()
+ request.headers = {}
+ return request
+
+
+class TestGetRequestUserId:
+ """Tests for get_request_user_id() - 1st defense layer.
+
+ Verifies:
+ - System Users (__ prefix) in HTTP header are rejected with KeyError
+ - Public Users pass through successfully
+ """
+
+ def test_system_user_raises_error(self, user_manager, mock_request):
+ """Test System User in header raises KeyError."""
+ mock_request.headers = {"comfy-user": "__system"}
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ with pytest.raises(KeyError, match="Unknown user"):
+ user_manager.get_request_user_id(mock_request)
+
+ def test_system_user_cache_raises_error(self, user_manager, mock_request):
+ """Test System User cache raises KeyError."""
+ mock_request.headers = {"comfy-user": "__cache"}
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ with pytest.raises(KeyError, match="Unknown user"):
+ user_manager.get_request_user_id(mock_request)
+
+ def test_normal_user_works(self, user_manager, mock_request):
+ """Test normal user access works."""
+ mock_request.headers = {"comfy-user": "default"}
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ user_id = user_manager.get_request_user_id(mock_request)
+ assert user_id == "default"
+
+ def test_unknown_user_raises_error(self, user_manager, mock_request):
+ """Test unknown user raises KeyError."""
+ mock_request.headers = {"comfy-user": "unknown_user"}
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ with pytest.raises(KeyError, match="Unknown user"):
+ user_manager.get_request_user_id(mock_request)
+
+
+class TestGetRequestUserFilepath:
+ """Tests for get_request_user_filepath() - 2nd defense layer.
+
+ Verifies:
+ - Returns None when get_public_user_directory() returns None (System User)
+ - Acts as backup defense if 1st layer is bypassed
+ """
+
+ def test_system_user_returns_none(self, user_manager, mock_request, mock_user_directory):
+ """Test System User returns None (structural blocking)."""
+        # get_request_user_id() would already raise KeyError for a System User header,
+        # so exercise the 2nd layer directly by patching get_public_user_directory() to return None.
+ mock_request.headers = {"comfy-user": "default"}
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ # Patch get_public_user_directory to return None for testing
+ with patch.object(folder_paths, 'get_public_user_directory', return_value=None):
+ result = user_manager.get_request_user_filepath(mock_request, "test.txt")
+ assert result is None
+
+ def test_normal_user_gets_path(self, user_manager, mock_request, mock_user_directory):
+ """Test normal user gets valid filepath."""
+ mock_request.headers = {"comfy-user": "default"}
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ path = user_manager.get_request_user_filepath(mock_request, "test.txt")
+ assert path is not None
+ assert "default" in path
+ assert path.endswith("test.txt")
+
+
+class TestAddUser:
+ """Tests for add_user() - 3rd defense layer (creation-time blocking).
+
+ Verifies:
+ - System User name (__ prefix) creation is rejected with ValueError
+ - Sanitized usernames that become System User are also rejected
+ """
+
+ def test_system_user_prefix_name_raises(self, user_manager):
+ """Test System User prefix in name raises ValueError."""
+ with pytest.raises(ValueError, match="System User prefix not allowed"):
+ user_manager.add_user("__system")
+
+ def test_system_user_prefix_cache_raises(self, user_manager):
+ """Test System User cache prefix raises ValueError."""
+ with pytest.raises(ValueError, match="System User prefix not allowed"):
+ user_manager.add_user("__cache")
+
+ def test_sanitized_system_user_prefix_raises(self, user_manager):
+ """Test sanitized name becoming System User prefix raises ValueError (bypass prevention)."""
+ # "__test" directly starts with System User prefix
+ with pytest.raises(ValueError, match="System User prefix not allowed"):
+ user_manager.add_user("__test")
+
+ def test_normal_user_creation(self, user_manager, mock_user_directory):
+ """Test normal user creation works."""
+ user_id = user_manager.add_user("Normal User")
+ assert user_id is not None
+ assert not user_id.startswith("__")
+ assert "Normal-User" in user_id or "Normal_User" in user_id
+
+ def test_empty_name_raises(self, user_manager):
+ """Test empty name raises ValueError."""
+ with pytest.raises(ValueError, match="username not provided"):
+ user_manager.add_user("")
+
+ def test_whitespace_only_raises(self, user_manager):
+ """Test whitespace-only name raises ValueError."""
+ with pytest.raises(ValueError, match="username not provided"):
+ user_manager.add_user(" ")
+
+
+class TestDefenseLayers:
+ """Integration tests for all three defense layers.
+
+ Verifies:
+ - Each defense layer blocks System Users independently
+ - System User bypass is impossible through any layer
+ """
+
+ def test_layer1_get_request_user_id(self, user_manager, mock_request):
+ """Test 1st defense layer blocks System Users."""
+ mock_request.headers = {"comfy-user": "__system"}
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ with pytest.raises(KeyError):
+ user_manager.get_request_user_id(mock_request)
+
+ def test_layer2_get_public_user_directory(self):
+ """Test 2nd defense layer blocks System Users."""
+ result = folder_paths.get_public_user_directory("__system")
+ assert result is None
+
+ def test_layer3_add_user(self, user_manager):
+ """Test 3rd defense layer blocks System User creation."""
+ with pytest.raises(ValueError):
+ user_manager.add_user("__system")
diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py
index 63361309f..3a54941e6 100644
--- a/tests-unit/comfy_quant/test_mixed_precision.py
+++ b/tests-unit/comfy_quant/test_mixed_precision.py
@@ -2,6 +2,7 @@ import unittest
import torch
import sys
import os
+import json
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -15,6 +16,7 @@ if not has_gpu():
from comfy import ops
from comfy.quant_ops import QuantizedTensor
+import comfy.utils
class SimpleModel(torch.nn.Module):
@@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
}
+ state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
# Create model and load state dict (strict=False because custom loading pops keys)
- model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+ model = SimpleModel(operations=ops.mixed_precision_ops({}))
model.load_state_dict(state_dict, strict=False)
# Verify weights are wrapped in QuantizedTensor
@@ -115,7 +118,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Forward pass
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
- output = model(input_tensor)
+ with torch.inference_mode():
+ output = model(input_tensor)
self.assertEqual(output.shape, (5, 40))
@@ -141,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
- model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+ state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+ model = SimpleModel(operations=ops.mixed_precision_ops({}))
model.load_state_dict(state_dict1, strict=False)
# Save state dict
@@ -178,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
- model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+ state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+ model = SimpleModel(operations=ops.mixed_precision_ops({}))
model.load_state_dict(state_dict, strict=False)
# Add a weight function (simulating LoRA)
@@ -215,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
+ state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
- model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+ model = SimpleModel(operations=ops.mixed_precision_ops({}))
with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False)
diff --git a/tests-unit/folder_paths_test/system_user_test.py b/tests-unit/folder_paths_test/system_user_test.py
new file mode 100644
index 000000000..cd46459f1
--- /dev/null
+++ b/tests-unit/folder_paths_test/system_user_test.py
@@ -0,0 +1,206 @@
+"""Tests for System User Protection in folder_paths.py
+
+Tests cover:
+- get_system_user_directory(): Internal API for custom nodes to access System User directories
+- get_public_user_directory(): HTTP endpoint access with System User blocking
+- Backward compatibility: Existing APIs unchanged
+- Security: Path traversal and injection prevention
+"""
+
+import pytest
+import os
+import tempfile
+
+from folder_paths import (
+ get_system_user_directory,
+ get_public_user_directory,
+ get_user_directory,
+ set_user_directory,
+)
+
+
+@pytest.fixture(scope="module")
+def mock_user_directory():
+ """Create a temporary user directory for testing."""
+ with tempfile.TemporaryDirectory() as temp_dir:
+ original_dir = get_user_directory()
+ set_user_directory(temp_dir)
+ yield temp_dir
+ set_user_directory(original_dir)
+
+
+class TestGetSystemUserDirectory:
+ """Tests for get_system_user_directory() - internal API for System User directories.
+
+ Verifies:
+ - Custom nodes can access System User directories via internal API
+ - Input validation prevents path traversal attacks
+ """
+
+ def test_default_name(self, mock_user_directory):
+ """Test default 'system' name."""
+ path = get_system_user_directory()
+ assert path.endswith("__system")
+ assert mock_user_directory in path
+
+ def test_custom_name(self, mock_user_directory):
+ """Test custom system user name."""
+ path = get_system_user_directory("cache")
+ assert path.endswith("__cache")
+ assert "__cache" in path
+
+ def test_name_with_underscore(self, mock_user_directory):
+ """Test name with underscore in middle."""
+ path = get_system_user_directory("my_cache")
+ assert "__my_cache" in path
+
+ def test_empty_name_raises(self):
+ """Test empty name raises ValueError."""
+ with pytest.raises(ValueError, match="cannot be empty"):
+ get_system_user_directory("")
+
+ def test_none_name_raises(self):
+ """Test None name raises ValueError."""
+ with pytest.raises(ValueError, match="cannot be empty"):
+ get_system_user_directory(None)
+
+ def test_name_starting_with_underscore_raises(self):
+ """Test name starting with underscore raises ValueError."""
+ with pytest.raises(ValueError, match="should not start with underscore"):
+ get_system_user_directory("_system")
+
+ def test_path_traversal_raises(self):
+ """Test path traversal attempt raises ValueError (security)."""
+ with pytest.raises(ValueError, match="Invalid system user name"):
+ get_system_user_directory("../escape")
+
+ def test_path_traversal_middle_raises(self):
+ """Test path traversal in middle raises ValueError (security)."""
+ with pytest.raises(ValueError, match="Invalid system user name"):
+ get_system_user_directory("system/../other")
+
+ def test_special_chars_raise(self):
+ """Test special characters raise ValueError (security)."""
+ with pytest.raises(ValueError, match="Invalid system user name"):
+ get_system_user_directory("system!")
+
+ def test_returns_absolute_path(self, mock_user_directory):
+ """Test returned path is absolute."""
+ path = get_system_user_directory("test")
+ assert os.path.isabs(path)
+
+
+class TestGetPublicUserDirectory:
+ """Tests for get_public_user_directory() - HTTP endpoint access with System User blocking.
+
+ Verifies:
+ - System Users (__ prefix) return None, blocking HTTP access
+ - Public Users get valid paths
+ - New endpoints using this function are automatically protected
+ """
+
+ def test_normal_user(self, mock_user_directory):
+ """Test normal user returns valid path."""
+ path = get_public_user_directory("default")
+ assert path is not None
+ assert "default" in path
+ assert mock_user_directory in path
+
+ def test_system_user_returns_none(self):
+ """Test System User (__ prefix) returns None - blocks HTTP access."""
+ assert get_public_user_directory("__system") is None
+
+ def test_system_user_cache_returns_none(self):
+ """Test System User cache returns None."""
+ assert get_public_user_directory("__cache") is None
+
+ def test_empty_user_returns_none(self):
+ """Test empty user returns None."""
+ assert get_public_user_directory("") is None
+
+ def test_none_user_returns_none(self):
+ """Test None user returns None."""
+ assert get_public_user_directory(None) is None
+
+ def test_header_injection_returns_none(self):
+ """Test header injection attempt returns None (security)."""
+ assert get_public_user_directory("__system\r\nX-Injected: true") is None
+
+ def test_null_byte_injection_returns_none(self):
+ """Test null byte injection handling (security)."""
+ # Note: startswith check happens before any path operations
+ result = get_public_user_directory("user\x00__system")
+        # The name does not start with __, so the prefix check passes and a path is returned.
+        # Rejecting null bytes and path traversal is the caller's responsibility (user_manager).
+        assert result is not None
+
+ def test_path_traversal_attempt(self, mock_user_directory):
+ """Test path traversal attempt handling."""
+        # This function does not validate paths, only the reserved "__" prefix;
+        # path traversal is handled by the caller.
+        path = get_public_user_directory("../../../etc/passwd")
+        # The name does not start with __, so a non-None path is returned;
+        # actual path validation happens in user_manager.
+        assert path is not None
+
+ def test_returns_absolute_path(self, mock_user_directory):
+ """Test returned path is absolute."""
+ path = get_public_user_directory("testuser")
+ assert path is not None
+ assert os.path.isabs(path)
+
+
+class TestBackwardCompatibility:
+ """Tests for backward compatibility with existing APIs.
+
+ Verifies:
+ - get_user_directory() API unchanged
+ - Existing user data remains accessible
+ """
+
+ def test_get_user_directory_unchanged(self, mock_user_directory):
+ """Test get_user_directory() still works as before."""
+ user_dir = get_user_directory()
+ assert user_dir is not None
+ assert os.path.isabs(user_dir)
+ assert user_dir == mock_user_directory
+
+ def test_existing_user_accessible(self, mock_user_directory):
+ """Test existing users can access their directories."""
+ path = get_public_user_directory("default")
+ assert path is not None
+ assert "default" in path
+
+
+class TestEdgeCases:
+ """Tests for edge cases in System User detection.
+
+ Verifies:
+ - Only __ prefix is blocked (not _, not middle __)
+ - Bypass attempts are prevented
+ """
+
+ def test_prefix_only(self):
+ """Test prefix-only string is blocked."""
+ assert get_public_user_directory("__") is None
+
+ def test_single_underscore_allowed(self):
+ """Test single underscore prefix is allowed (not System User)."""
+ path = get_public_user_directory("_system")
+ assert path is not None
+ assert "_system" in path
+
+ def test_triple_underscore_blocked(self):
+ """Test triple underscore is blocked (starts with __)."""
+ assert get_public_user_directory("___system") is None
+
+ def test_underscore_in_middle_allowed(self):
+ """Test underscore in middle is allowed."""
+ path = get_public_user_directory("my__system")
+ assert path is not None
+ assert "my__system" in path
+
+ def test_leading_space_allowed(self):
+ """Test leading space + prefix is allowed (doesn't start with __)."""
+ path = get_public_user_directory(" __system")
+ assert path is not None
diff --git a/tests-unit/prompt_server_test/system_user_endpoint_test.py b/tests-unit/prompt_server_test/system_user_endpoint_test.py
new file mode 100644
index 000000000..22ac00af9
--- /dev/null
+++ b/tests-unit/prompt_server_test/system_user_endpoint_test.py
@@ -0,0 +1,375 @@
+"""E2E Tests for System User Protection HTTP Endpoints
+
+Tests cover:
+- HTTP endpoint blocking: System Users cannot access /userdata (GET, POST, DELETE, move)
+- User creation blocking: System User names cannot be created via POST /users
+- Backward compatibility: Public Users work as before
+- Custom node scenario: Internal API works while HTTP is blocked
+- Structural security: get_public_user_directory() provides automatic protection
+"""
+
+import os
+from unittest.mock import patch
+
+import pytest
+from aiohttp import web
+
+import folder_paths
+from app.user_manager import UserManager
+
+
+@pytest.fixture
+def mock_user_directory(tmp_path):
+ """Create a temporary user directory."""
+ original_dir = folder_paths.get_user_directory()
+ folder_paths.set_user_directory(str(tmp_path))
+ yield tmp_path
+ folder_paths.set_user_directory(original_dir)
+
+
+@pytest.fixture
+def user_manager_multi_user(mock_user_directory):
+ """Create UserManager in multi-user mode."""
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ um = UserManager()
+ # Add test users
+ um.users = {"default": "default", "test_user_123": "Test User"}
+ yield um
+
+
+@pytest.fixture
+def app_multi_user(user_manager_multi_user):
+ """Create app with multi-user mode enabled."""
+ app = web.Application()
+ routes = web.RouteTableDef()
+ user_manager_multi_user.add_routes(routes)
+ app.add_routes(routes)
+ return app
+
+
+class TestSystemUserEndpointBlocking:
+ """E2E tests for System User blocking on all HTTP endpoints.
+
+ Verifies:
+ - GET /userdata blocked for System Users
+ - POST /userdata blocked for System Users
+ - DELETE /userdata blocked for System Users
+ - POST /userdata/.../move/... blocked for System Users
+ """
+
+ @pytest.mark.asyncio
+ async def test_userdata_get_blocks_system_user(
+ self, aiohttp_client, app_multi_user, mock_user_directory
+ ):
+ """
+ GET /userdata with System User header should be blocked.
+ """
+ # Create test directory for System User (simulating internal creation)
+ system_user_dir = mock_user_directory / "__system"
+ system_user_dir.mkdir()
+ (system_user_dir / "secret.txt").write_text("sensitive data")
+
+ client = await aiohttp_client(app_multi_user)
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ # Attempt to access System User's data via HTTP
+ resp = await client.get(
+ "/userdata?dir=.",
+ headers={"comfy-user": "__system"}
+ )
+
+ # Should be blocked (403 Forbidden or similar error)
+ assert resp.status in [400, 403, 500], \
+ f"System User access should be blocked, got {resp.status}"
+
+ @pytest.mark.asyncio
+ async def test_userdata_post_blocks_system_user(
+ self, aiohttp_client, app_multi_user, mock_user_directory
+ ):
+ """
+ POST /userdata with System User header should be blocked.
+ """
+ client = await aiohttp_client(app_multi_user)
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ resp = await client.post(
+ "/userdata/test.txt",
+ headers={"comfy-user": "__system"},
+ data=b"malicious content"
+ )
+
+ assert resp.status in [400, 403, 500], \
+ f"System User write should be blocked, got {resp.status}"
+
+ # Verify no file was created
+ assert not (mock_user_directory / "__system" / "test.txt").exists()
+
+ @pytest.mark.asyncio
+ async def test_userdata_delete_blocks_system_user(
+ self, aiohttp_client, app_multi_user, mock_user_directory
+ ):
+ """
+ DELETE /userdata with System User header should be blocked.
+ """
+ # Create a file in System User directory
+ system_user_dir = mock_user_directory / "__system"
+ system_user_dir.mkdir()
+ secret_file = system_user_dir / "secret.txt"
+ secret_file.write_text("do not delete")
+
+ client = await aiohttp_client(app_multi_user)
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ resp = await client.delete(
+ "/userdata/secret.txt",
+ headers={"comfy-user": "__system"}
+ )
+
+ assert resp.status in [400, 403, 500], \
+ f"System User delete should be blocked, got {resp.status}"
+
+ # Verify file still exists
+ assert secret_file.exists()
+
+ @pytest.mark.asyncio
+ async def test_v2_userdata_blocks_system_user(
+ self, aiohttp_client, app_multi_user, mock_user_directory
+ ):
+ """
+ GET /v2/userdata with System User header should be blocked.
+ """
+ client = await aiohttp_client(app_multi_user)
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ resp = await client.get(
+ "/v2/userdata",
+ headers={"comfy-user": "__system"}
+ )
+
+ assert resp.status in [400, 403, 500], \
+ f"System User v2 access should be blocked, got {resp.status}"
+
+ @pytest.mark.asyncio
+ async def test_move_userdata_blocks_system_user(
+ self, aiohttp_client, app_multi_user, mock_user_directory
+ ):
+ """
+ POST /userdata/{file}/move/{dest} with System User header should be blocked.
+ """
+ system_user_dir = mock_user_directory / "__system"
+ system_user_dir.mkdir()
+ (system_user_dir / "source.txt").write_text("sensitive data")
+
+ client = await aiohttp_client(app_multi_user)
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ resp = await client.post(
+ "/userdata/source.txt/move/dest.txt",
+ headers={"comfy-user": "__system"}
+ )
+
+ assert resp.status in [400, 403, 500], \
+ f"System User move should be blocked, got {resp.status}"
+
+ # Verify source file still exists (move was blocked)
+ assert (system_user_dir / "source.txt").exists()
+
+
+class TestSystemUserCreationBlocking:
+ """E2E tests for blocking System User name creation via POST /users.
+
+ Verifies:
+ - POST /users returns 400 for System User name (not 500)
+ """
+
+ @pytest.mark.asyncio
+ async def test_post_users_blocks_system_user_name(
+ self, aiohttp_client, app_multi_user
+ ):
+ """POST /users with System User name should return 400 Bad Request."""
+ client = await aiohttp_client(app_multi_user)
+
+ resp = await client.post(
+ "/users",
+ json={"username": "__system"}
+ )
+
+ assert resp.status == 400, \
+ f"System User creation should return 400, got {resp.status}"
+
+ @pytest.mark.asyncio
+ async def test_post_users_blocks_system_user_prefix_variations(
+ self, aiohttp_client, app_multi_user
+ ):
+ """POST /users with any System User prefix variation should return 400 Bad Request."""
+ client = await aiohttp_client(app_multi_user)
+
+ system_user_names = ["__system", "__cache", "__config", "__anything"]
+
+ for name in system_user_names:
+ resp = await client.post("/users", json={"username": name})
+ assert resp.status == 400, \
+ f"System User name '{name}' should return 400, got {resp.status}"
+
+
+class TestPublicUserStillWorks:
+ """E2E tests for backward compatibility - Public Users should work as before.
+
+ Verifies:
+ - Public Users can access their data via HTTP
+ - Public Users can create files via HTTP
+ """
+
+ @pytest.mark.asyncio
+ async def test_public_user_can_access_userdata(
+ self, aiohttp_client, app_multi_user, mock_user_directory
+ ):
+ """
+ Public Users should still be able to access their data.
+ """
+ # Create test directory for Public User
+ user_dir = mock_user_directory / "default"
+ user_dir.mkdir()
+ test_dir = user_dir / "workflows"
+ test_dir.mkdir()
+ (test_dir / "test.json").write_text('{"test": true}')
+
+ client = await aiohttp_client(app_multi_user)
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ resp = await client.get(
+ "/userdata?dir=workflows",
+ headers={"comfy-user": "default"}
+ )
+
+ assert resp.status == 200
+ data = await resp.json()
+ assert "test.json" in data
+
+ @pytest.mark.asyncio
+ async def test_public_user_can_create_files(
+ self, aiohttp_client, app_multi_user, mock_user_directory
+ ):
+ """
+ Public Users should still be able to create files.
+ """
+ # Create user directory
+ user_dir = mock_user_directory / "default"
+ user_dir.mkdir()
+
+ client = await aiohttp_client(app_multi_user)
+
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ resp = await client.post(
+ "/userdata/newfile.txt",
+ headers={"comfy-user": "default"},
+ data=b"user content"
+ )
+
+ assert resp.status == 200
+ assert (user_dir / "newfile.txt").exists()
+
+
+class TestCustomNodeScenario:
+ """Tests for custom node use case: internal API access vs HTTP blocking.
+
+ Verifies:
+ - Internal API (get_system_user_directory) works for custom nodes
+ - HTTP endpoint cannot access data created via internal API
+ """
+
+ def test_internal_api_can_access_system_user(self, mock_user_directory):
+ """
+ Internal API (get_system_user_directory) should work for custom nodes.
+ """
+ # Custom node uses internal API
+ system_path = folder_paths.get_system_user_directory("mynode_config")
+
+ assert system_path is not None
+ assert "__mynode_config" in system_path
+
+ # Can create and write to System User directory
+ os.makedirs(system_path, exist_ok=True)
+ config_file = os.path.join(system_path, "settings.json")
+ with open(config_file, "w") as f:
+ f.write('{"api_key": "secret"}')
+
+ assert os.path.exists(config_file)
+
+ @pytest.mark.asyncio
+ async def test_http_cannot_access_internal_data(
+ self, aiohttp_client, app_multi_user, mock_user_directory
+ ):
+ """
+ HTTP endpoint cannot access data created via internal API.
+ """
+ # Custom node creates data via internal API
+ system_path = folder_paths.get_system_user_directory("mynode_config")
+ os.makedirs(system_path, exist_ok=True)
+ with open(os.path.join(system_path, "secret.json"), "w") as f:
+ f.write('{"api_key": "secret"}')
+
+ client = await aiohttp_client(app_multi_user)
+
+ # Attacker tries to access via HTTP
+ with patch('app.user_manager.args') as mock_args:
+ mock_args.multi_user = True
+ resp = await client.get(
+ "/userdata/secret.json",
+ headers={"comfy-user": "__mynode_config"}
+ )
+
+ # Should be blocked
+ assert resp.status in [400, 403, 500]
+
+
+class TestStructuralSecurity:
+ """Tests for structural security pattern.
+
+ Verifies:
+ - get_public_user_directory() automatically blocks System Users
+ - New endpoints using this function are automatically protected
+ """
+
+ def test_get_public_user_directory_blocks_system_user(self):
+ """
+ Any code using get_public_user_directory() is automatically protected.
+ """
+ # This is the structural security - any new endpoint using this function
+ # will automatically block System Users
+ assert folder_paths.get_public_user_directory("__system") is None
+ assert folder_paths.get_public_user_directory("__cache") is None
+ assert folder_paths.get_public_user_directory("__anything") is None
+
+ # Public Users work
+ assert folder_paths.get_public_user_directory("default") is not None
+ assert folder_paths.get_public_user_directory("user123") is not None
+
+ def test_structural_security_pattern(self, mock_user_directory):
+ """
+ Demonstrate the structural security pattern for new endpoints.
+
+ Any new endpoint should follow this pattern:
+ 1. Get user from request
+ 2. Use get_public_user_directory() - automatically blocks System Users
+ 3. If None, return error
+ """
+ def new_endpoint_handler(user_id: str) -> str | None:
+ """Example of how new endpoints should be implemented."""
+ user_path = folder_paths.get_public_user_directory(user_id)
+ if user_path is None:
+ return None # Blocked
+ return user_path
+
+ # System Users are automatically blocked
+ assert new_endpoint_handler("__system") is None
+ assert new_endpoint_handler("__secret") is None
+
+ # Public Users work
+ assert new_endpoint_handler("default") is not None
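+
+    def test_pattern_in_a_hypothetical_endpoint(self, mock_user_directory):
+        """Sketch: apply the same pattern inside a made-up handler.
+
+        The handler below is purely illustrative and not part of ComfyUI's
+        API; only folder_paths.get_public_user_directory() is the real
+        function exercised here. It returns HTTP-style status codes to mirror
+        how an aiohttp route would respond.
+        """
+        def hypothetical_handler(user_header: str | None) -> int:
+            # Resolve the requesting user the same way the endpoints above do,
+            # falling back to the default user when no header is supplied.
+            user_id = user_header or "default"
+            user_path = folder_paths.get_public_user_directory(user_id)
+            if user_path is None:
+                return 403  # System Users are rejected structurally
+            return 200
+
+        # System Users are blocked with no endpoint-specific checks needed.
+        assert hypothetical_handler("__system") == 403
+        assert hypothetical_handler("__mynode_config") == 403
+
+        # Public Users (and the default fallback) still pass through.
+        assert hypothetical_handler("default") == 200
+        assert hypothetical_handler(None) == 200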