mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 21:00:49 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
ca6da3d9fb
1
.github/workflows/test-ci.yml
vendored
1
.github/workflows/test-ci.yml
vendored
@ -5,6 +5,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
paths-ignore:
|
||||
- 'app/**'
|
||||
- 'input/**'
|
||||
|
||||
4
.github/workflows/test-execution.yml
vendored
4
.github/workflows/test-execution.yml
vendored
@ -2,9 +2,9 @@ name: Execution Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
||||
4
.github/workflows/test-launch.yml
vendored
4
.github/workflows/test-launch.yml
vendored
@ -2,9 +2,9 @@ name: Test server launches without errors
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
||||
4
.github/workflows/test-unit.yml
vendored
4
.github/workflows/test-unit.yml
vendored
@ -2,9 +2,9 @@ name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
branches: [ main, master, release/** ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
|
||||
1
.github/workflows/update-version.yml
vendored
1
.github/workflows/update-version.yml
vendored
@ -6,6 +6,7 @@ on:
|
||||
- "pyproject.toml"
|
||||
branches:
|
||||
- master
|
||||
- release/**
|
||||
|
||||
jobs:
|
||||
update-version:
|
||||
|
||||
@ -124,6 +124,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
||||
- Minor versions will be used for releases off the master branch.
|
||||
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
||||
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||
- Serves as the foundation for the desktop release
|
||||
|
||||
@ -214,6 +217,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
|
||||
|
||||
### Instructions:
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
@ -44,7 +44,7 @@ class ModelFileManager:
|
||||
@routes.get("/experiment/models/{folder}")
|
||||
async def get_all_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
if folder not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = self.get_model_file_list(folder)
|
||||
return web.json_response(files)
|
||||
@ -55,7 +55,7 @@ class ModelFileManager:
|
||||
path_index = int(request.match_info.get("path_index", None))
|
||||
filename = request.match_info.get("filename", None)
|
||||
|
||||
if not folder_name in folder_paths.folder_names_and_paths:
|
||||
if folder_name not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
|
||||
folders = folder_paths.folder_names_and_paths[folder_name]
|
||||
|
||||
@ -2,6 +2,25 @@ import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.ops
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import logging
|
||||
|
||||
@ -17,24 +16,7 @@ class Output:
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
if crop:
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||
else:
|
||||
scale_size = (size, size)
|
||||
|
||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
|
||||
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
@ -73,7 +55,7 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image, crop=True):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||
|
||||
outputs = Output()
|
||||
|
||||
@ -143,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
# if multiple conds, split based on primary region
|
||||
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||
region = window.get_region_index(len(cond_in))
|
||||
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
|
||||
logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
|
||||
cond_in = [cond_in[region]]
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
@ -188,6 +188,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
audio_cond = cond_value.cond
|
||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||
# Handle vace_context (temporal dim is 3)
|
||||
elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
vace_cond = cond_value.cond
|
||||
if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim):
|
||||
sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list)
|
||||
new_cond_item[cond_key] = cond_value._copy_with(sliced_vace)
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||
|
||||
@ -527,7 +527,8 @@ class HookKeyframeGroup:
|
||||
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
|
||||
break
|
||||
# if eval_c is outside the percent range, stop looking further
|
||||
else: break
|
||||
else:
|
||||
break
|
||||
# update steps current context is used
|
||||
self._current_used_steps += 1
|
||||
# update current timestep this was performed on
|
||||
|
||||
@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
||||
|
||||
def default_noise_sampler(x, seed=None):
|
||||
if seed is not None:
|
||||
if x.device == torch.device("cpu"):
|
||||
seed += 1
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
else:
|
||||
@ -1618,6 +1621,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
x = x + sde_noise * sigmas[i + 1] * s_noise
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
|
||||
"""Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
|
||||
"""Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
|
||||
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||
@ -1765,7 +1779,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
# Predictor
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
x_pred = denoised
|
||||
else:
|
||||
tau_t = tau_func(sigmas[i + 1])
|
||||
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
|
||||
@ -1786,7 +1800,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
if tau_t > 0 and s_noise > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
|
||||
x_pred = x_pred + noise
|
||||
return x
|
||||
return x_pred
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@ -270,7 +270,7 @@ class ChromaRadiance(Chroma):
|
||||
bad_keys = tuple(
|
||||
k
|
||||
for k, v in overrides.items()
|
||||
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
|
||||
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
|
||||
)
|
||||
if bad_keys:
|
||||
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
|
||||
|
||||
@ -3,7 +3,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||
import model_management, model_patcher
|
||||
import model_management
|
||||
import model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
|
||||
@ -491,7 +491,8 @@ class NextDiT(nn.Module):
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
|
||||
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
@ -625,7 +626,7 @@ class NextDiT(nn.Module):
|
||||
if pooled is not None:
|
||||
pooled = self.clip_text_pooled_proj(pooled)
|
||||
else:
|
||||
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
|
||||
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||
|
||||
|
||||
@ -30,6 +30,13 @@ except ImportError as e:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattn3 import sageattn3_blackwell
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
@wrap_attn
|
||||
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
exception_fallback = False
|
||||
if (q.device.type != "cuda" or
|
||||
q.dtype not in (torch.float16, torch.bfloat16) or
|
||||
mask is not None):
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if skip_reshape:
|
||||
B, H, L, D = q.shape
|
||||
if H != heads:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
q_s, k_s, v_s = q, k, v
|
||||
N = q.shape[2]
|
||||
dim_head = D
|
||||
else:
|
||||
B, N, inner_dim = q.shape
|
||||
if inner_dim % heads != 0:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
dim_head = inner_dim // heads
|
||||
|
||||
if dim_head >= 256 or N <= 1024:
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not skip_reshape:
|
||||
q_s, k_s, v_s = map(
|
||||
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
B, H, L, D = q_s.shape
|
||||
|
||||
try:
|
||||
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
|
||||
except Exception as e:
|
||||
exception_fallback = True
|
||||
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
|
||||
|
||||
if exception_fallback:
|
||||
if not skip_reshape:
|
||||
del q_s, k_s, v_s
|
||||
return attention_pytorch(
|
||||
q, k, v, heads,
|
||||
mask=mask,
|
||||
attn_precision=attn_precision,
|
||||
skip_reshape=False,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if skip_reshape:
|
||||
if not skip_output_reshape:
|
||||
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||
else:
|
||||
if skip_output_reshape:
|
||||
pass
|
||||
else:
|
||||
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
|
||||
|
||||
return out
|
||||
|
||||
try:
|
||||
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
|
||||
@ -650,6 +744,8 @@ optimized_attention_masked = optimized_attention
|
||||
# register core-supported attention functions
|
||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("sage", attention_sage)
|
||||
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||
register_attention_function("sage3", attention3_sage)
|
||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("flash", attention_flash)
|
||||
if model_management.xformers_enabled():
|
||||
|
||||
@ -394,7 +394,8 @@ class Model(nn.Module):
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
@ -548,7 +549,8 @@ class Encoder(nn.Module):
|
||||
conv3d=False, time_compress=None,
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
|
||||
@ -45,7 +45,7 @@ class LitEma(nn.Module):
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
assert key not in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
@ -54,7 +54,7 @@ class LitEma(nn.Module):
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
assert key not in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
|
||||
@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
|
||||
operations=operations
|
||||
)
|
||||
|
||||
def forward(self, timestep, hidden_states):
|
||||
self.use_additional_t_cond = use_additional_t_cond
|
||||
if self.use_additional_t_cond:
|
||||
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, timestep, hidden_states, addition_t_cond=None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
if self.use_additional_t_cond:
|
||||
if addition_t_cond is None:
|
||||
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
|
||||
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
@ -320,10 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
default_ref_method="index",
|
||||
image_model=None,
|
||||
final_layer=True,
|
||||
use_additional_t_cond=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -334,12 +345,14 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.default_ref_method = default_ref_method
|
||||
|
||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
use_additional_t_cond=use_additional_t_cond,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
@ -361,6 +374,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
if self.default_ref_method == "index_timestep_zero":
|
||||
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
|
||||
|
||||
if final_layer:
|
||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||
@ -370,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patch_size = self.patch_size
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||
t_len = t
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
|
||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
if t_len > 1:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
|
||||
else:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
|
||||
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
@ -398,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
ref_latents=None,
|
||||
additional_t_cond=None,
|
||||
transformer_options={},
|
||||
control=None,
|
||||
**kwargs
|
||||
@ -416,14 +438,19 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_method = kwargs.get("ref_latents_method", "index")
|
||||
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
|
||||
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
||||
negative_ref_method = ref_method == "negative_index"
|
||||
timestep_zero = ref_method == "index_timestep_zero"
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif negative_ref_method:
|
||||
index -= 1
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
else:
|
||||
index = 1
|
||||
h_offset = 0
|
||||
@ -453,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
patches = transformer_options.get("patches", {})
|
||||
@ -508,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
|
||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||
|
||||
@ -71,7 +71,7 @@ def count_params(model, verbose=False):
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
if "target" not in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
|
||||
@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
input_channels=3,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
@ -245,7 +246,7 @@ class Encoder3d(nn.Module):
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
output_channels=3,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
@ -378,7 +380,7 @@ class Decoder3d(nn.Module):
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
@ -449,6 +451,7 @@ class WanVAE(nn.Module):
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
image_channels=3,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@ -460,11 +463,11 @@ class WanVAE(nn.Module):
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def encode(self, x):
|
||||
|
||||
@ -1110,7 +1110,7 @@ class Lumina2(BaseModel):
|
||||
if 'num_tokens' not in out:
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
||||
|
||||
clip_text_pooled = kwargs["pooled_output"] # Newbie
|
||||
clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
|
||||
@ -259,7 +259,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
if "__x0__" in state_dict_keys: # x0 pred
|
||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
else:
|
||||
dit_config["use_x0"] = False
|
||||
@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
dit_config["ffn_dim_multiplier"] = 4.0
|
||||
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
||||
if ctd_weight is not None:
|
||||
if ctd_weight is not None: # NewBie
|
||||
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
||||
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
|
||||
elif dit_config["dim"] == 3840: # Z image
|
||||
dit_config["n_heads"] = 30
|
||||
dit_config["n_kv_heads"] = 30
|
||||
@ -618,6 +619,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
|
||||
dit_config["default_ref_method"] = "index_timestep_zero"
|
||||
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
|
||||
dit_config["use_additional_t_cond"] = True
|
||||
dit_config["default_ref_method"] = "negative_index"
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
|
||||
@ -26,6 +26,7 @@ import importlib
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -333,13 +334,15 @@ except:
|
||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||
|
||||
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
||||
AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
@ -1016,8 +1019,8 @@ NUM_STREAMS = 0
|
||||
if args.async_offload is not None:
|
||||
NUM_STREAMS = args.async_offload
|
||||
else:
|
||||
# Enable by default on Nvidia
|
||||
if is_nvidia():
|
||||
# Enable by default on Nvidia and AMD
|
||||
if is_nvidia() or is_amd():
|
||||
NUM_STREAMS = 2
|
||||
|
||||
if args.disable_async_offload:
|
||||
@ -1123,6 +1126,16 @@ if not args.disable_pinned_memory:
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
|
||||
def discard_cuda_async_error():
|
||||
try:
|
||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
_ = a + b
|
||||
torch.cuda.synchronize()
|
||||
except torch.AcceleratorError:
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
|
||||
def pin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
@ -1155,6 +1168,9 @@ def pin_memory(tensor):
|
||||
PINNED_MEMORY[ptr] = size
|
||||
TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
else:
|
||||
logging.warning("Pin error.")
|
||||
discard_cuda_async_error()
|
||||
|
||||
return False
|
||||
|
||||
@ -1183,6 +1199,9 @@ def unpin_memory(tensor):
|
||||
if len(PINNED_MEMORY) == 0:
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
return True
|
||||
else:
|
||||
logging.warning("Unpin error.")
|
||||
discard_cuda_async_error()
|
||||
|
||||
return False
|
||||
|
||||
@ -1523,6 +1542,10 @@ def soft_empty_cache(force=False):
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
|
||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||
return memory_required, minimum_memory_required
|
||||
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_prepare_sampling,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@ -720,7 +720,7 @@ class Sampler:
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
@ -984,9 +984,6 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
@ -1013,6 +1010,24 @@ class CFGGuider:
|
||||
else:
|
||||
latent_shapes = [latent_image.shape]
|
||||
|
||||
if denoise_mask is not None:
|
||||
if denoise_mask.is_nested:
|
||||
denoise_masks = denoise_mask.unbind()
|
||||
denoise_masks = denoise_masks[:len(latent_shapes)]
|
||||
else:
|
||||
denoise_masks = [denoise_mask]
|
||||
|
||||
for i in range(len(denoise_masks), len(latent_shapes)):
|
||||
denoise_masks.append(torch.ones(latent_shapes[i]))
|
||||
|
||||
for i in range(len(denoise_masks)):
|
||||
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
|
||||
|
||||
if len(denoise_masks) > 1:
|
||||
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
||||
else:
|
||||
denoise_mask = denoise_masks[0]
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
|
||||
56
comfy/sd.py
56
comfy/sd.py
@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
import comfy.text_encoders.newbie
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -321,6 +323,7 @@ class VAE:
|
||||
self.latent_channels = 4
|
||||
self.latent_dim = 2
|
||||
self.output_channels = 3
|
||||
self.pad_channel_value = None
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
@ -435,6 +438,7 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 64
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
@ -546,7 +550,9 @@ class VAE:
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
||||
self.pad_channel_value = 1.0
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
@ -582,6 +588,7 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 8
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 4096
|
||||
self.downscale_ratio = 4096
|
||||
self.latent_dim = 2
|
||||
@ -690,17 +697,28 @@ class VAE:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
if not self.crop_input:
|
||||
return pixels
|
||||
if self.crop_input:
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
x = (dims[d] // downscale_ratio) * downscale_ratio
|
||||
x_offset = (dims[d] % downscale_ratio) // 2
|
||||
if x != dims[d]:
|
||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
x = (dims[d] // downscale_ratio) * downscale_ratio
|
||||
x_offset = (dims[d] % downscale_ratio) // 2
|
||||
if x != dims[d]:
|
||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||
if pixels.shape[-1] > self.output_channels:
|
||||
pixels = pixels[..., :self.output_channels]
|
||||
elif pixels.shape[-1] < self.output_channels:
|
||||
if self.pad_channel_value is not None:
|
||||
if isinstance(self.pad_channel_value, str):
|
||||
mode = self.pad_channel_value
|
||||
value = None
|
||||
else:
|
||||
mode = "constant"
|
||||
value = self.pad_channel_value
|
||||
|
||||
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
||||
return pixels
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
@ -992,6 +1010,7 @@ class CLIPType(Enum):
|
||||
OVIS = 21
|
||||
KANDINSKY5 = 22
|
||||
KANDINSKY5_IMAGE = 23
|
||||
NEWBIE = 24
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@ -1022,6 +1041,7 @@ class TEModel(Enum):
|
||||
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
||||
QWEN3_4B = 16
|
||||
QWEN3_2B = 17
|
||||
JINA_CLIP_2 = 18
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1031,6 +1051,8 @@ def detect_te_model(sd):
|
||||
return TEModel.CLIP_H
|
||||
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
|
||||
return TEModel.CLIP_L
|
||||
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
|
||||
return TEModel.JINA_CLIP_2
|
||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
if weight.shape[-1] == 4096:
|
||||
@ -1191,6 +1213,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif te_model == TEModel.QWEN3_2B:
|
||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||
elif te_model == TEModel.JINA_CLIP_2:
|
||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
@ -1246,6 +1271,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||
elif clip_type == CLIPType.NEWBIE:
|
||||
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
|
||||
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
|
||||
clip_data_gemma = clip_data[0]
|
||||
clip_data_jina = clip_data[1]
|
||||
else:
|
||||
clip_data_gemma = clip_data[1]
|
||||
clip_data_jina = clip_data[0]
|
||||
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
|
||||
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
|
||||
@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||
@ -513,6 +513,8 @@ class SDTokenizer:
|
||||
self.embedding_size = embedding_size
|
||||
self.embedding_key = embedding_key
|
||||
|
||||
self.disable_weights = disable_weights
|
||||
|
||||
def _try_get_embedding(self, embedding_name:str):
|
||||
'''
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
@ -547,7 +549,7 @@ class SDTokenizer:
|
||||
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||
|
||||
text = escape_important(text)
|
||||
if kwargs.get("disable_weights", False):
|
||||
if kwargs.get("disable_weights", self.disable_weights):
|
||||
parsed_weights = [(text, 1.0)]
|
||||
else:
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
||||
@ -154,7 +154,8 @@ class TAEHV(nn.Module):
|
||||
self._show_progress_bar = value
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if x.shape[1] % 4 != 0:
|
||||
# pad at end to multiple of 4
|
||||
@ -167,5 +168,6 @@ class TAEHV(nn.Module):
|
||||
def decode(self, x, **kwargs):
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_shuffle(x, self.patch_size)
|
||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||
|
||||
219
comfy/text_encoders/jina_clip_2.py
Normal file
219
comfy/text_encoders/jina_clip_2.py
Normal file
@ -0,0 +1,219 @@
|
||||
# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
|
||||
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
|
||||
# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
from comfy import sd1_clip
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
|
||||
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
|
||||
|
||||
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
|
||||
@dataclass
|
||||
class XLMRobertaConfig:
|
||||
vocab_size: int = 250002
|
||||
type_vocab_size: int = 1
|
||||
hidden_size: int = 1024
|
||||
num_hidden_layers: int = 24
|
||||
num_attention_heads: int = 16
|
||||
rotary_emb_base: float = 20000.0
|
||||
intermediate_size: int = 4096
|
||||
hidden_act: str = "gelu"
|
||||
hidden_dropout_prob: float = 0.1
|
||||
attention_probs_dropout_prob: float = 0.1
|
||||
layer_norm_eps: float = 1e-05
|
||||
bos_token_id: int = 0
|
||||
eos_token_id: int = 2
|
||||
pad_token_id: int = 1
|
||||
|
||||
class XLMRobertaEmbeddings(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
|
||||
self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, input_ids=None, embeddings=None):
|
||||
if input_ids is not None and embeddings is None:
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
|
||||
if embeddings is not None:
|
||||
token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
embeddings = embeddings + token_type_embeddings
|
||||
return embeddings
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, base, device=None):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
|
||||
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
||||
if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self._cos_cached = emb.cos().to(dtype)
|
||||
self._sin_cached = emb.sin().to(dtype)
|
||||
|
||||
def forward(self, q, k):
|
||||
batch, seqlen, heads, head_dim = q.shape
|
||||
self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
|
||||
|
||||
cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
|
||||
sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)
|
||||
|
||||
def rotate_half(x):
|
||||
size = x.shape[-1] // 2
|
||||
x1, x2 = x[..., :size], x[..., size:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
class MHA(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = embed_dim // config.num_attention_heads
|
||||
|
||||
self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
|
||||
self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
|
||||
self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, mask=None, optimized_attention=None):
|
||||
qkv = self.Wqkv(x)
|
||||
batch_size, seq_len, _ = qkv.shape
|
||||
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
|
||||
q, k, v = qkv.unbind(2)
|
||||
|
||||
q, k = self.rotary_emb(q, k)
|
||||
|
||||
# NHD -> HND
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
|
||||
return self.out_proj(out)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
|
||||
self.activation = F.gelu
|
||||
self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.activation(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
|
||||
self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, mask=None, optimized_attention=None):
|
||||
mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
|
||||
hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
|
||||
mlp_out = self.mlp(hidden_states)
|
||||
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
|
||||
return hidden_states
|
||||
|
||||
class XLMRobertaEncoder(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None):
|
||||
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
|
||||
return hidden_states
|
||||
|
||||
class XLMRobertaModel_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
|
||||
self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
|
||||
x = self.emb_ln(x)
|
||||
x = self.emb_drop(x)
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
|
||||
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
|
||||
|
||||
sequence_output = self.encoder(x, attention_mask=mask)
|
||||
|
||||
# Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
|
||||
pooled_output = None
|
||||
if attention_mask is None:
|
||||
pooled_output = sequence_output.mean(dim=1)
|
||||
else:
|
||||
attention_mask = attention_mask.to(sequence_output.dtype)
|
||||
pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Intermediate output is not yet implemented, use None for placeholder
|
||||
return sequence_output, None, pooled_output
|
||||
|
||||
class XLMRobertaModel(nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.config = XLMRobertaConfig(**config_dict)
|
||||
self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
|
||||
self.num_layers = self.config.num_hidden_layers
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.model.embeddings.word_embeddings = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
|
||||
class JinaClip2TextModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
||||
|
||||
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
|
||||
@ -3,13 +3,11 @@ import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
import math
|
||||
import logging
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
import comfy.model_management
|
||||
from . import qwen_vl
|
||||
|
||||
@dataclass
|
||||
@ -177,7 +175,7 @@ class Gemma3_4B_Config:
|
||||
num_key_value_heads: int = 4
|
||||
max_position_embeddings: int = 131072
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta = [10000.0, 1000000.0]
|
||||
rope_theta = [1000000.0, 10000.0]
|
||||
transformer_type: str = "gemma3"
|
||||
head_dim = 256
|
||||
rms_norm_add = True
|
||||
@ -186,8 +184,8 @@ class Gemma3_4B_Config:
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
sliding_attention = [False, False, False, False, False, 1024]
|
||||
rope_scale = [1.0, 8.0]
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
@ -370,7 +368,7 @@ class TransformerBlockGemma2(nn.Module):
|
||||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
|
||||
if config.sliding_attention is not None:
|
||||
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
|
||||
else:
|
||||
self.sliding_attention = False
|
||||
@ -387,7 +385,12 @@ class TransformerBlockGemma2(nn.Module):
|
||||
if self.transformer_type == 'gemma3':
|
||||
if self.sliding_attention:
|
||||
if x.shape[1] > self.sliding_attention:
|
||||
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
|
||||
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
|
||||
sliding_mask.tril_(diagonal=-self.sliding_attention)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask + sliding_mask
|
||||
else:
|
||||
attention_mask = sliding_mask
|
||||
freqs_cis = freqs_cis[1]
|
||||
else:
|
||||
freqs_cis = freqs_cis[0]
|
||||
|
||||
@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
|
||||
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
|
||||
|
||||
class Gemma3_4BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
|
||||
62
comfy/text_encoders/newbie.py
Normal file
62
comfy/text_encoders/newbie.py
Normal file
@ -0,0 +1,62 @@
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
import comfy.text_encoders.lumina2
|
||||
|
||||
class NewBieTokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
|
||||
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
raise NotImplementedError
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class NewBieTEModel(torch.nn.Module):
|
||||
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
|
||||
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
|
||||
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = {dtype, dtype_gemma}
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.gemma.set_clip_options(options)
|
||||
self.jina.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.gemma.reset_clip_options()
|
||||
self.jina.reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_gemma = token_weight_pairs["gemma"]
|
||||
token_weight_pairs_jina = token_weight_pairs["jina"]
|
||||
|
||||
gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
|
||||
jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)
|
||||
|
||||
return gemma_out, jina_pooled, gemma_extra
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.0.self_attn.q_norm.weight" in sd:
|
||||
return self.gemma.load_sd(sd)
|
||||
else:
|
||||
return self.jina.load_sd(sd)
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class NewBieTEModel_(NewBieTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return NewBieTEModel_
|
||||
@ -1230,6 +1230,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
out_sd = {}
|
||||
layers = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if k == scaled_fp8_key:
|
||||
continue
|
||||
if not k.startswith(model_prefix):
|
||||
out_sd[k] = state_dict[k]
|
||||
continue
|
||||
|
||||
@ -10,7 +10,6 @@ from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
from PIL import Image
|
||||
|
||||
@ -26,11 +26,9 @@ if TYPE_CHECKING:
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from ._resources import Resources, ResourcesLocal
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL
|
||||
from ._util import MESH, VOXEL, SVG as _SVG
|
||||
|
||||
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
|
||||
|
||||
class FolderType(str, Enum):
|
||||
input = "input"
|
||||
@ -77,16 +75,6 @@ class NumberDisplay(str, Enum):
|
||||
slider = "slider"
|
||||
|
||||
|
||||
class _StringIOType(str):
|
||||
def __ne__(self, value: object) -> bool:
|
||||
if self == "*" or value == "*":
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return True
|
||||
a = frozenset(self.split(","))
|
||||
b = frozenset(value.split(","))
|
||||
return not (b.issubset(a) or a.issubset(b))
|
||||
|
||||
class _ComfyType(ABC):
|
||||
Type = Any
|
||||
io_type: str = None
|
||||
@ -126,8 +114,7 @@ def comfytype(io_type: str, **kwargs):
|
||||
new_cls.__module__ = cls.__module__
|
||||
new_cls.__doc__ = cls.__doc__
|
||||
# assign ComfyType attributes, if needed
|
||||
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details)
|
||||
new_cls.io_type = _StringIOType(io_type)
|
||||
new_cls.io_type = io_type
|
||||
if hasattr(new_cls, "Input") and new_cls.Input is not None:
|
||||
new_cls.Input.Parent = new_cls
|
||||
if hasattr(new_cls, "Output") and new_cls.Output is not None:
|
||||
@ -166,7 +153,7 @@ class Input(_IO_V3):
|
||||
'''
|
||||
Base class for a V3 Input.
|
||||
'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__()
|
||||
self.id = id
|
||||
self.display_name = display_name
|
||||
@ -174,6 +161,7 @@ class Input(_IO_V3):
|
||||
self.tooltip = tooltip
|
||||
self.lazy = lazy
|
||||
self.extra_dict = extra_dict if extra_dict is not None else {}
|
||||
self.rawLink = raw_link
|
||||
|
||||
def as_dict(self):
|
||||
return prune_dict({
|
||||
@ -181,10 +169,11 @@ class Input(_IO_V3):
|
||||
"optional": self.optional,
|
||||
"tooltip": self.tooltip,
|
||||
"lazy": self.lazy,
|
||||
"rawLink": self.rawLink,
|
||||
}) | prune_dict(self.extra_dict)
|
||||
|
||||
def get_io_type(self):
|
||||
return _StringIOType(self.io_type)
|
||||
return self.io_type
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self]
|
||||
@ -195,8 +184,8 @@ class WidgetInput(Input):
|
||||
'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: Any=None,
|
||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
self.default = default
|
||||
self.socketless = socketless
|
||||
self.widget_type = widget_type
|
||||
@ -218,13 +207,14 @@ class Output(_IO_V3):
|
||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
self.id = id
|
||||
self.display_name = display_name
|
||||
self.display_name = display_name if display_name else id
|
||||
self.tooltip = tooltip
|
||||
self.is_output_list = is_output_list
|
||||
|
||||
def as_dict(self):
|
||||
display_name = self.display_name if self.display_name else self.id
|
||||
return prune_dict({
|
||||
"display_name": self.display_name,
|
||||
"display_name": display_name,
|
||||
"tooltip": self.tooltip,
|
||||
"is_output_list": self.is_output_list,
|
||||
})
|
||||
@ -252,8 +242,8 @@ class Boolean(ComfyTypeIO):
|
||||
'''Boolean input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: bool=None, label_on: str=None, label_off: str=None,
|
||||
socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
self.label_on = label_on
|
||||
self.label_off = label_off
|
||||
self.default: bool
|
||||
@ -272,8 +262,8 @@ class Int(ComfyTypeIO):
|
||||
'''Integer input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.step = step
|
||||
@ -298,8 +288,8 @@ class Float(ComfyTypeIO):
|
||||
'''Float input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.step = step
|
||||
@ -324,8 +314,8 @@ class String(ComfyTypeIO):
|
||||
'''String input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
|
||||
socketless: bool=None, force_input: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input)
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link)
|
||||
self.multiline = multiline
|
||||
self.placeholder = placeholder
|
||||
self.dynamic_prompts = dynamic_prompts
|
||||
@ -358,12 +348,14 @@ class Combo(ComfyTypeIO):
|
||||
image_folder: FolderType=None,
|
||||
remote: RemoteOptions=None,
|
||||
socketless: bool=None,
|
||||
extra_dict=None,
|
||||
raw_link: bool=None,
|
||||
):
|
||||
if isinstance(options, type) and issubclass(options, Enum):
|
||||
options = [v.value for v in options]
|
||||
if isinstance(default, Enum):
|
||||
default = default.value
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||
self.multiselect = False
|
||||
self.options = options
|
||||
self.control_after_generate = control_after_generate
|
||||
@ -387,10 +379,6 @@ class Combo(ComfyTypeIO):
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
self.options = options if options is not None else []
|
||||
|
||||
@property
|
||||
def io_type(self):
|
||||
return self.options
|
||||
|
||||
@comfytype(io_type="COMBO")
|
||||
class MultiCombo(ComfyTypeI):
|
||||
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||
@ -399,8 +387,8 @@ class MultiCombo(ComfyTypeI):
|
||||
class Input(Combo.Input):
|
||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||
socketless: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless)
|
||||
socketless: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link)
|
||||
self.multiselect = True
|
||||
self.placeholder = placeholder
|
||||
self.chip = chip
|
||||
@ -433,9 +421,9 @@ class Webcam(ComfyTypeIO):
|
||||
Type = str
|
||||
def __init__(
|
||||
self, id: str, display_name: str=None, optional=False,
|
||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None
|
||||
tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None
|
||||
):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link)
|
||||
|
||||
|
||||
@comfytype(io_type="MASK")
|
||||
@ -656,7 +644,7 @@ class Video(ComfyTypeIO):
|
||||
|
||||
@comfytype(io_type="SVG")
|
||||
class SVG(ComfyTypeIO):
|
||||
Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3
|
||||
Type = _SVG
|
||||
|
||||
@comfytype(io_type="LORA_MODEL")
|
||||
class LoraModel(ComfyTypeIO):
|
||||
@ -788,7 +776,7 @@ class MultiType:
|
||||
'''
|
||||
Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values.
|
||||
'''
|
||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
# if id is an Input, then use that Input with overridden values
|
||||
self.input_override = None
|
||||
if isinstance(id, Input):
|
||||
@ -801,7 +789,7 @@ class MultiType:
|
||||
# if is a widget input, make sure widget_type is set appropriately
|
||||
if isinstance(self.input_override, WidgetInput):
|
||||
self.input_override.widget_type = self.input_override.get_io_type()
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
self._io_types = types
|
||||
|
||||
@property
|
||||
@ -855,8 +843,8 @@ class MatchType(ComfyTypeIO):
|
||||
|
||||
class Input(Input):
|
||||
def __init__(self, id: str, template: MatchType.Template,
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link)
|
||||
self.template = template
|
||||
|
||||
def as_dict(self):
|
||||
@ -867,6 +855,8 @@ class MatchType(ComfyTypeIO):
|
||||
class Output(Output):
|
||||
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
if not id and not display_name:
|
||||
display_name = "MATCHTYPE"
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
self.template = template
|
||||
|
||||
@ -879,24 +869,30 @@ class DynamicInput(Input, ABC):
|
||||
'''
|
||||
Abstract class for dynamic input registration.
|
||||
'''
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return []
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class DynamicOutput(Output, ABC):
|
||||
'''
|
||||
Abstract class for dynamic output registration.
|
||||
'''
|
||||
def __init__(self, id: str=None, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
pass
|
||||
|
||||
def get_dynamic(self) -> list[Output]:
|
||||
return []
|
||||
|
||||
def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]:
|
||||
if prefix_list is None:
|
||||
prefix_list = []
|
||||
if id is not None:
|
||||
prefix_list = prefix_list + [id]
|
||||
return prefix_list
|
||||
|
||||
def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str:
|
||||
assert not (prefix_list is None and id is None)
|
||||
if prefix_list is None:
|
||||
return id
|
||||
elif id is not None:
|
||||
prefix_list = prefix_list + [id]
|
||||
return ".".join(prefix_list)
|
||||
|
||||
@comfytype(io_type="COMFY_AUTOGROW_V3")
|
||||
class Autogrow(ComfyTypeI):
|
||||
@ -933,14 +929,6 @@ class Autogrow(ComfyTypeI):
|
||||
def validate(self):
|
||||
self.input.validate()
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
real_inputs = []
|
||||
for name, input in self.cached_inputs.items():
|
||||
if name in live_inputs:
|
||||
real_inputs.append(input)
|
||||
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
|
||||
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
|
||||
|
||||
class TemplatePrefix(_AutogrowTemplate):
|
||||
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
|
||||
super().__init__(input)
|
||||
@ -985,22 +973,45 @@ class Autogrow(ComfyTypeI):
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return self.template.get_all()
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + self.template.get_all()
|
||||
|
||||
def validate(self):
|
||||
self.template.validate()
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
curr_prefix = f"{curr_prefix}{self.id}."
|
||||
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend
|
||||
for inner_dict in d.values():
|
||||
if self.id in inner_dict:
|
||||
del inner_dict[self.id]
|
||||
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||
# NOTE: purposely do not include self in out_dict; instead use only the template inputs
|
||||
# need to figure out names based on template type
|
||||
is_names = ("names" in value[1]["template"])
|
||||
is_prefix = ("prefix" in value[1]["template"])
|
||||
input = value[1]["template"]["input"]
|
||||
if is_names:
|
||||
min = value[1]["template"]["min"]
|
||||
names = value[1]["template"]["names"]
|
||||
max = len(names)
|
||||
elif is_prefix:
|
||||
prefix = value[1]["template"]["prefix"]
|
||||
min = value[1]["template"]["min"]
|
||||
max = value[1]["template"]["max"]
|
||||
names = [f"{prefix}{i}" for i in range(max)]
|
||||
# need to create a new input based on the contents of input
|
||||
template_input = None
|
||||
for _, dict_input in input.items():
|
||||
# for now, get just the first value from dict_input
|
||||
template_input = list(dict_input.values())[0]
|
||||
new_dict = {}
|
||||
for i, name in enumerate(names):
|
||||
expected_id = finalize_prefix(curr_prefix, name)
|
||||
if expected_id in live_inputs:
|
||||
# required
|
||||
if i < min:
|
||||
type_dict = new_dict.setdefault("required", {})
|
||||
# optional
|
||||
else:
|
||||
type_dict = new_dict.setdefault("optional", {})
|
||||
type_dict[name] = template_input
|
||||
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
||||
class DynamicCombo(ComfyTypeI):
|
||||
@ -1023,23 +1034,6 @@ class DynamicCombo(ComfyTypeI):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.options = options
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
# check if dynamic input's id is in live_inputs
|
||||
if self.id in live_inputs:
|
||||
curr_prefix = f"{curr_prefix}{self.id}."
|
||||
key = live_inputs[self.id]
|
||||
selected_option = None
|
||||
for option in self.options:
|
||||
if option.key == key:
|
||||
selected_option = option
|
||||
break
|
||||
if selected_option is not None:
|
||||
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
|
||||
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return [input for option in self.options for input in option.inputs]
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + [input for option in self.options for input in option.inputs]
|
||||
|
||||
@ -1054,6 +1048,24 @@ class DynamicCombo(ComfyTypeI):
|
||||
for input in option.inputs:
|
||||
input.validate()
|
||||
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||
finalized_id = finalize_prefix(curr_prefix)
|
||||
if finalized_id in live_inputs:
|
||||
key = live_inputs[finalized_id]
|
||||
selected_option = None
|
||||
# get options from dict
|
||||
options: list[dict[str, str | dict[str, Any]]] = value[1]["options"]
|
||||
for option in options:
|
||||
if option["key"] == key:
|
||||
selected_option = option
|
||||
break
|
||||
if selected_option is not None:
|
||||
parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix)
|
||||
# add self to inputs
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
|
||||
class DynamicSlot(ComfyTypeI):
|
||||
Type = dict[str, Any]
|
||||
@ -1076,17 +1088,8 @@ class DynamicSlot(ComfyTypeI):
|
||||
self.force_input = True
|
||||
self.slot.force_input = True
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
if self.id in live_inputs:
|
||||
curr_prefix = f"{curr_prefix}{self.id}."
|
||||
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
|
||||
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return [self.slot] + self.inputs
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + [self.slot] + self.inputs
|
||||
return [self.slot] + self.inputs
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
@ -1100,17 +1103,41 @@ class DynamicSlot(ComfyTypeI):
|
||||
for input in self.inputs:
|
||||
input.validate()
|
||||
|
||||
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
|
||||
dynamic = d.setdefault("dynamic_paths", {})
|
||||
if self is not None:
|
||||
dynamic[self.id] = f"{curr_prefix}{self.id}"
|
||||
for i in inputs:
|
||||
if not isinstance(i, DynamicInput):
|
||||
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
|
||||
@staticmethod
|
||||
def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None):
|
||||
finalized_id = finalize_prefix(curr_prefix)
|
||||
if finalized_id in live_inputs:
|
||||
inputs = value[1]["inputs"]
|
||||
parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix)
|
||||
# add self to inputs
|
||||
out_dict[input_type][finalized_id] = value
|
||||
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
|
||||
def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]:
|
||||
return DYNAMIC_INPUT_LOOKUP[io_type]
|
||||
|
||||
def setup_dynamic_input_funcs():
|
||||
# Autogrow.Input
|
||||
register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic)
|
||||
# DynamicCombo.Input
|
||||
register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic)
|
||||
# DynamicSlot.Input
|
||||
register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic)
|
||||
|
||||
if len(DYNAMIC_INPUT_LOOKUP) == 0:
|
||||
setup_dynamic_input_funcs()
|
||||
|
||||
class V3Data(TypedDict):
|
||||
hidden_inputs: dict[str, Any]
|
||||
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
|
||||
dynamic_paths: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
class HiddenHolder:
|
||||
def __init__(self, unique_id: str, prompt: Any,
|
||||
@ -1146,6 +1173,10 @@ class HiddenHolder:
|
||||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder:
|
||||
return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None)
|
||||
|
||||
class Hidden(str, Enum):
|
||||
'''
|
||||
Enumerator for requesting hidden variables in nodes.
|
||||
@ -1251,61 +1282,56 @@ class Schema:
|
||||
- verify ids on inputs and outputs are unique - both internally and in relation to each other
|
||||
'''
|
||||
nested_inputs: list[Input] = []
|
||||
if self.inputs is not None:
|
||||
for input in self.inputs:
|
||||
for input in self.inputs:
|
||||
if not isinstance(input, DynamicInput):
|
||||
nested_inputs.extend(input.get_all())
|
||||
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
|
||||
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
|
||||
input_ids = [i.id for i in nested_inputs]
|
||||
output_ids = [o.id for o in self.outputs]
|
||||
input_set = set(input_ids)
|
||||
output_set = set(output_ids)
|
||||
issues = []
|
||||
issues: list[str] = []
|
||||
# verify ids are unique per list
|
||||
if len(input_set) != len(input_ids):
|
||||
issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.")
|
||||
if len(output_set) != len(output_ids):
|
||||
issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.")
|
||||
# verify ids are unique between lists
|
||||
intersection = input_set & output_set
|
||||
if len(intersection) > 0:
|
||||
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
|
||||
if len(issues) > 0:
|
||||
raise ValueError("\n".join(issues))
|
||||
# validate inputs and outputs
|
||||
if self.inputs is not None:
|
||||
for input in self.inputs:
|
||||
input.validate()
|
||||
if self.outputs is not None:
|
||||
for output in self.outputs:
|
||||
output.validate()
|
||||
for input in self.inputs:
|
||||
input.validate()
|
||||
for output in self.outputs:
|
||||
output.validate()
|
||||
|
||||
def finalize(self):
|
||||
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
||||
# ensure inputs, outputs, and hidden are lists
|
||||
if self.inputs is None:
|
||||
self.inputs = []
|
||||
if self.outputs is None:
|
||||
self.outputs = []
|
||||
if self.hidden is None:
|
||||
self.hidden = []
|
||||
# if is an api_node, will need key-related hidden
|
||||
if self.is_api_node:
|
||||
if self.hidden is None:
|
||||
self.hidden = []
|
||||
if Hidden.auth_token_comfy_org not in self.hidden:
|
||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||
if Hidden.api_key_comfy_org not in self.hidden:
|
||||
self.hidden.append(Hidden.api_key_comfy_org)
|
||||
# if is an output_node, will need prompt and extra_pnginfo
|
||||
if self.is_output_node:
|
||||
if self.hidden is None:
|
||||
self.hidden = []
|
||||
if Hidden.prompt not in self.hidden:
|
||||
self.hidden.append(Hidden.prompt)
|
||||
if Hidden.extra_pnginfo not in self.hidden:
|
||||
self.hidden.append(Hidden.extra_pnginfo)
|
||||
# give outputs without ids default ids
|
||||
if self.outputs is not None:
|
||||
for i, output in enumerate(self.outputs):
|
||||
if output.id is None:
|
||||
output.id = f"_{i}_{output.io_type}_"
|
||||
for i, output in enumerate(self.outputs):
|
||||
if output.id is None:
|
||||
output.id = f"_{i}_{output.io_type}_"
|
||||
|
||||
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
|
||||
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
|
||||
def get_v1_info(self, cls) -> NodeInfoV1:
|
||||
# get V1 inputs
|
||||
input = create_input_dict_v1(self.inputs, live_inputs)
|
||||
input = create_input_dict_v1(self.inputs)
|
||||
if self.hidden:
|
||||
for hidden in self.hidden:
|
||||
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||
@ -1385,33 +1411,54 @@ class Schema:
|
||||
)
|
||||
return info
|
||||
|
||||
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
|
||||
out_dict = {
|
||||
"required": {},
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
hidden = d.pop("hidden", None)
|
||||
parse_class_inputs(out_dict, live_inputs, d)
|
||||
if hidden is not None and include_hidden:
|
||||
out_dict["hidden"] = hidden
|
||||
v3_data = {}
|
||||
dynamic_paths = out_dict.pop("dynamic_paths", None)
|
||||
if dynamic_paths is not None:
|
||||
v3_data["dynamic_paths"] = dynamic_paths
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
for input_type, inner_d in curr_dict.items():
|
||||
for id, value in inner_d.items():
|
||||
io_type = value[0]
|
||||
if io_type in DYNAMIC_INPUT_LOOKUP:
|
||||
# dynamic inputs need to be handled with lookup functions
|
||||
dynamic_input_func = get_dynamic_input_func(io_type)
|
||||
new_prefix = handle_prefix(curr_prefix, id)
|
||||
dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix)
|
||||
else:
|
||||
# non-dynamic inputs get directly transferred
|
||||
finalized_id = finalize_prefix(curr_prefix, id)
|
||||
out_dict[input_type][finalized_id] = value
|
||||
if curr_prefix:
|
||||
out_dict["dynamic_paths"][finalized_id] = finalized_id
|
||||
|
||||
def create_input_dict_v1(inputs: list[Input]) -> dict:
|
||||
input = {
|
||||
"required": {}
|
||||
}
|
||||
add_to_input_dict_v1(input, inputs, live_inputs)
|
||||
for i in inputs:
|
||||
add_to_dict_v1(i, input)
|
||||
return input
|
||||
|
||||
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
|
||||
for i in inputs:
|
||||
if isinstance(i, DynamicInput):
|
||||
add_to_dict_v1(i, d)
|
||||
if live_inputs is not None:
|
||||
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
||||
else:
|
||||
add_to_dict_v1(i, d)
|
||||
|
||||
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
|
||||
def add_to_dict_v1(i: Input, d: dict):
|
||||
key = "optional" if i.optional else "required"
|
||||
as_dict = i.as_dict()
|
||||
# for v1, we don't want to include the optional key
|
||||
as_dict.pop("optional", None)
|
||||
if dynamic_dict is None:
|
||||
value = (i.get_io_type(), as_dict)
|
||||
else:
|
||||
value = (i.get_io_type(), as_dict, dynamic_dict)
|
||||
d.setdefault(key, {})[i.id] = value
|
||||
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||
|
||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||
@ -1423,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
values = values.copy()
|
||||
result = {}
|
||||
|
||||
create_tuple = v3_data.get("create_dynamic_tuple", False)
|
||||
|
||||
for key, path in paths.items():
|
||||
parts = path.split(".")
|
||||
current = result
|
||||
@ -1431,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
is_last = (i == len(parts) - 1)
|
||||
|
||||
if is_last:
|
||||
current[p] = values.pop(key, None)
|
||||
value = values.pop(key, None)
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
else:
|
||||
current = current.setdefault(p, {})
|
||||
|
||||
@ -1446,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
SCHEMA = None
|
||||
|
||||
# filled in during execution
|
||||
resources: Resources = None
|
||||
hidden: HiddenHolder = None
|
||||
|
||||
@classmethod
|
||||
@ -1493,7 +1544,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
return [name for name in kwargs if kwargs[name] is None]
|
||||
|
||||
def __init__(self):
|
||||
self.local_resources: ResourcesLocal = None
|
||||
self.__class__.VALIDATE_CLASS()
|
||||
|
||||
@classmethod
|
||||
@ -1556,12 +1606,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
|
||||
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]:
|
||||
"""Creates clone of real node class to prevent monkey-patching."""
|
||||
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
||||
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
||||
# set hidden
|
||||
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
|
||||
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
|
||||
return type_clone
|
||||
|
||||
@final
|
||||
@ -1678,19 +1728,10 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
|
||||
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||
schema = cls.FINALIZE_SCHEMA()
|
||||
info = schema.get_v1_info(cls, live_inputs)
|
||||
input = info.input
|
||||
if not include_hidden:
|
||||
input.pop("hidden", None)
|
||||
if return_schema:
|
||||
v3_data: V3Data = {}
|
||||
dynamic = input.pop("dynamic_paths", None)
|
||||
if dynamic is not None:
|
||||
v3_data["dynamic_paths"] = dynamic
|
||||
return input, schema, v3_data
|
||||
return input
|
||||
info = schema.get_v1_info(cls)
|
||||
return info.input
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
@ -1809,7 +1850,7 @@ class NodeOutput(_NodeOutputInternal):
|
||||
return self.args if len(self.args) > 0 else None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "NodeOutput":
|
||||
def from_dict(cls, data: dict[str, Any]) -> NodeOutput:
|
||||
args = ()
|
||||
ui = None
|
||||
expand = None
|
||||
@ -1904,8 +1945,8 @@ __all__ = [
|
||||
"Tracks",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
# "DynamicCombo",
|
||||
# "Autogrow",
|
||||
"DynamicCombo",
|
||||
"Autogrow",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
|
||||
@ -1,72 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
class ResourceKey(ABC):
|
||||
Type = Any
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
class TorchDictFolderFilename(ResourceKey):
|
||||
'''Key for requesting a torch file via file_name from a folder category.'''
|
||||
Type = dict[str, torch.Tensor]
|
||||
def __init__(self, folder_name: str, file_name: str):
|
||||
self.folder_name = folder_name
|
||||
self.file_name = file_name
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.folder_name, self.file_name))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, TorchDictFolderFilename):
|
||||
return False
|
||||
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.folder_name} -> {self.file_name}"
|
||||
|
||||
class Resources(ABC):
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
pass
|
||||
|
||||
class ResourcesLocal(Resources):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.local_resources: dict[ResourceKey, Any] = {}
|
||||
|
||||
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||
cached = self.local_resources.get(key, None)
|
||||
if cached is not None:
|
||||
logging.info(f"Using cached resource '{key}'")
|
||||
return cached
|
||||
logging.info(f"Loading resource '{key}'")
|
||||
to_return = None
|
||||
if isinstance(key, TorchDictFolderFilename):
|
||||
if default is ...:
|
||||
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
||||
else:
|
||||
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
||||
if full_path is not None:
|
||||
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
||||
|
||||
if to_return is not None:
|
||||
self.local_resources[key] = to_return
|
||||
return to_return
|
||||
if default is not ...:
|
||||
return default
|
||||
raise Exception(f"Unsupported resource key type: {type(key)}")
|
||||
|
||||
|
||||
class _RESOURCES:
|
||||
ResourceKey = ResourceKey
|
||||
TorchDictFolderFilename = TorchDictFolderFilename
|
||||
Resources = Resources
|
||||
ResourcesLocal = ResourcesLocal
|
||||
@ -1,5 +1,6 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH
|
||||
from .image_types import SVG
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
@ -8,4 +9,5 @@ __all__ = [
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
"SVG",
|
||||
]
|
||||
|
||||
18
comfy_api/latest/_util/image_types.py
Normal file
18
comfy_api/latest/_util/image_types.py
Normal file
@ -0,0 +1,18 @@
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class SVG:
|
||||
"""Stores SVG representations via a list of BytesIO objects."""
|
||||
|
||||
def __init__(self, data: list[BytesIO]):
|
||||
self.data = data
|
||||
|
||||
def combine(self, other: 'SVG') -> 'SVG':
|
||||
return SVG(self.data + other.data)
|
||||
|
||||
@staticmethod
|
||||
def combine_all(svgs: list['SVG']) -> 'SVG':
|
||||
all_svgs_list: list[BytesIO] = []
|
||||
for svg_item in svgs:
|
||||
all_svgs_list.extend(svg_item.data)
|
||||
return SVG(all_svgs_list)
|
||||
@ -10,7 +10,7 @@ class Text2ImageTaskCreationRequest(BaseModel):
|
||||
size: str | None = Field(None)
|
||||
seed: int | None = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(True)
|
||||
watermark: bool | None = Field(False)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
@ -21,7 +21,7 @@ class Image2ImageTaskCreationRequest(BaseModel):
|
||||
size: str | None = Field("adaptive")
|
||||
seed: int | None = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(True)
|
||||
watermark: bool | None = Field(False)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
@ -37,7 +37,7 @@ class Seedream4TaskCreationRequest(BaseModel):
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
|
||||
@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel):
|
||||
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
|
||||
tools: list[GeminiTool] | None = Field(None)
|
||||
videoMetadata: GeminiVideoMetadata | None = Field(None)
|
||||
uploadImagesToStorage: bool = Field(True)
|
||||
|
||||
|
||||
class GeminiGenerateContentRequest(BaseModel):
|
||||
|
||||
@ -102,3 +102,12 @@ class ImageToVideoWithAudioRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
|
||||
|
||||
class MotionControlRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
image_url: str = Field(...)
|
||||
video_url: str = Field(...)
|
||||
keep_original_sound: str = Field(...)
|
||||
character_orientation: str = Field(...)
|
||||
mode: str = Field(..., description="'pro' or 'std'")
|
||||
|
||||
52
comfy_api_nodes/apis/openai_api.py
Normal file
52
comfy_api_nodes/apis/openai_api.py
Normal file
@ -0,0 +1,52 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Datum2(BaseModel):
|
||||
b64_json: str | None = Field(None, description="Base64 encoded image data")
|
||||
revised_prompt: str | None = Field(None, description="Revised prompt")
|
||||
url: str | None = Field(None, description="URL of the image")
|
||||
|
||||
|
||||
class InputTokensDetails(BaseModel):
|
||||
image_tokens: int | None = None
|
||||
text_tokens: int | None = None
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
input_tokens: int | None = None
|
||||
input_tokens_details: InputTokensDetails | None = None
|
||||
output_tokens: int | None = None
|
||||
total_tokens: int | None = None
|
||||
|
||||
|
||||
class OpenAIImageGenerationResponse(BaseModel):
|
||||
data: list[Datum2] | None = None
|
||||
usage: Usage | None = None
|
||||
|
||||
|
||||
class OpenAIImageEditRequest(BaseModel):
|
||||
background: str | None = Field(None, description="Background transparency")
|
||||
model: str = Field(...)
|
||||
moderation: str | None = Field(None)
|
||||
n: int | None = Field(None, description="The number of images to generate")
|
||||
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
|
||||
output_format: str | None = Field(None)
|
||||
prompt: str = Field(...)
|
||||
quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
|
||||
size: str | None = Field(None, description="Size of the output image")
|
||||
|
||||
|
||||
class OpenAIImageGenerationRequest(BaseModel):
|
||||
background: str | None = Field(None, description="Background transparency")
|
||||
model: str | None = Field(None)
|
||||
moderation: str | None = Field(None)
|
||||
n: int | None = Field(
|
||||
None,
|
||||
description="The number of images to generate.",
|
||||
)
|
||||
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
|
||||
output_format: str | None = Field(None)
|
||||
prompt: str = Field(...)
|
||||
quality: str | None = Field(None, description="The quality of the generated image")
|
||||
size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
|
||||
style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
|
||||
@ -1,10 +1,8 @@
|
||||
from inspect import cleandoc
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bfl_api import (
|
||||
BFLFluxExpandImageRequest,
|
||||
BFLFluxFillImageRequest,
|
||||
@ -28,7 +26,7 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: torch.Tensor):
|
||||
def convert_mask_to_image(mask: Input.Image):
|
||||
"""
|
||||
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
|
||||
"""
|
||||
@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor):
|
||||
|
||||
|
||||
class FluxProUltraImageNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
node_id="FluxProUltraImageNode",
|
||||
display_name="Flux 1.1 [pro] Ultra Image",
|
||||
category="api node/image/BFL",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
prompt_upsampling: bool = False,
|
||||
raw: bool = False,
|
||||
seed: int = 0,
|
||||
image_prompt: torch.Tensor | None = None,
|
||||
image_prompt: Input.Image | None = None,
|
||||
image_prompt_strength: float = 0.1,
|
||||
) -> IO.NodeOutput:
|
||||
if image_prompt is None:
|
||||
@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
|
||||
|
||||
class FluxKontextProImageNode(IO.ComfyNode):
|
||||
"""
|
||||
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
||||
node_id=cls.NODE_ID,
|
||||
display_name=cls.DISPLAY_NAME,
|
||||
category="api node/image/BFL",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
guidance: float,
|
||||
steps: int,
|
||||
input_image: torch.Tensor | None = None,
|
||||
input_image: Input.Image | None = None,
|
||||
seed=0,
|
||||
prompt_upsampling=False,
|
||||
) -> IO.NodeOutput:
|
||||
@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
||||
|
||||
|
||||
class FluxKontextMaxImageNode(FluxKontextProImageNode):
|
||||
"""
|
||||
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
DESCRIPTION = cleandoc(__doc__ or "")
|
||||
DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
|
||||
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
|
||||
NODE_ID = "FluxKontextMaxImageNode"
|
||||
DISPLAY_NAME = "Flux.1 Kontext [max] Image"
|
||||
|
||||
|
||||
class FluxProExpandNode(IO.ComfyNode):
|
||||
"""
|
||||
Outpaints image based on prompt.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
node_id="FluxProExpandNode",
|
||||
display_name="Flux.1 Expand Image",
|
||||
category="api node/image/BFL",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Outpaints image based on prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
top: int,
|
||||
@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
|
||||
|
||||
class FluxProFillNode(IO.ComfyNode):
|
||||
"""
|
||||
Inpaints image based on mask and prompt.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
node_id="FluxProFillNode",
|
||||
display_name="Flux.1 Fill Image",
|
||||
category="api node/image/BFL",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Inpaints image based on mask and prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask"),
|
||||
@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
image: Input.Image,
|
||||
mask: Input.Image,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
steps: int,
|
||||
@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
|
||||
class Flux2ProImageNode(IO.ComfyNode):
|
||||
|
||||
NODE_ID = "Flux2ProImageNode"
|
||||
DISPLAY_NAME = "Flux.2 [pro] Image"
|
||||
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Flux2ProImageNode",
|
||||
display_name="Flux.2 [pro] Image",
|
||||
node_id=cls.NODE_ID,
|
||||
display_name=cls.DISPLAY_NAME,
|
||||
category="api node/image/BFL",
|
||||
description="Generates images synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
default=True,
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
"If active, automatically modifies the prompt for more creative generation.",
|
||||
),
|
||||
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
|
||||
IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode):
|
||||
height: int,
|
||||
seed: int,
|
||||
prompt_upsampling: bool,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
reference_images = {}
|
||||
if images is not None:
|
||||
@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode):
|
||||
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
|
||||
ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=Flux2ProGenerateRequest(
|
||||
prompt=prompt,
|
||||
@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class Flux2MaxImageNode(Flux2ProImageNode):
|
||||
|
||||
NODE_ID = "Flux2MaxImageNode"
|
||||
DISPLAY_NAME = "Flux.2 [max] Image"
|
||||
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
|
||||
|
||||
|
||||
class BFLExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension):
|
||||
FluxProExpandNode,
|
||||
FluxProFillNode,
|
||||
Flux2ProImageNode,
|
||||
Flux2MaxImageNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -112,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
default=False,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||
optional=True,
|
||||
),
|
||||
@ -215,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
default=False,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||
optional=True,
|
||||
),
|
||||
@ -229,6 +229,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -269,7 +270,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedreamNode",
|
||||
display_name="ByteDance Seedream 4",
|
||||
display_name="ByteDance Seedream 4.5",
|
||||
category="api node/image/ByteDance",
|
||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||
inputs=[
|
||||
@ -346,7 +347,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
default=False,
|
||||
tooltip='Whether to add an "AI generated" watermark to the image.',
|
||||
optional=True,
|
||||
),
|
||||
@ -380,7 +381,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
sequential_image_generation: str = "disabled",
|
||||
max_images: int = 1,
|
||||
seed: int = 0,
|
||||
watermark: bool = True,
|
||||
watermark: bool = False,
|
||||
fail_on_partial: bool = True,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
@ -507,7 +508,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
default=False,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
@ -617,7 +618,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
default=False,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
@ -739,7 +740,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
default=False,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
@ -862,7 +863,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
default=False,
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
|
||||
@ -34,6 +34,7 @@ from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
download_url_to_image_tensor,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
||||
)
|
||||
parts = []
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
if part_type == "text" and part.text:
|
||||
parts.append(part)
|
||||
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
|
||||
elif part.inlineData and part.inlineData.mimeType == part_type:
|
||||
parts.append(part)
|
||||
elif part.fileData and part.fileData.mimeType == part_type:
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
for part in parts:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
if part.inlineData:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
else:
|
||||
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
|
||||
image_tensors.append(returned_image)
|
||||
if len(image_tensors) == 0:
|
||||
return torch.zeros((1, 1024, 1024, 4))
|
||||
@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
|
||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||
|
||||
|
||||
class GeminiImage2(IO.ComfyNode):
|
||||
@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
|
||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
|
||||
@ -51,6 +51,7 @@ from comfy_api_nodes.apis import (
|
||||
)
|
||||
from comfy_api_nodes.apis.kling_api import (
|
||||
ImageToVideoWithAudioRequest,
|
||||
MotionControlRequest,
|
||||
OmniImageParamImage,
|
||||
OmniParamImage,
|
||||
OmniParamVideo,
|
||||
@ -806,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@ -825,6 +827,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
response = await sync_op(
|
||||
@ -836,6 +839,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
@ -858,7 +862,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("duration", options=["5", "10"]),
|
||||
IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
|
||||
IO.Image.Input("first_frame"),
|
||||
IO.Image.Input(
|
||||
"end_frame",
|
||||
@ -871,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
optional=True,
|
||||
tooltip="Up to 6 additional reference images.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@ -892,11 +897,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
first_frame: Input.Image,
|
||||
end_frame: Input.Image | None = None,
|
||||
reference_images: Input.Image | None = None,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
if end_frame is not None and reference_images is not None:
|
||||
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
|
||||
if duration not in (5, 10) and end_frame is None and reference_images is None:
|
||||
raise ValueError(
|
||||
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
|
||||
)
|
||||
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
|
||||
image_list: list[OmniParamImage] = [
|
||||
@ -931,6 +941,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
prompt=prompt,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
@ -959,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
"reference_images",
|
||||
tooltip="Up to 7 reference images.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@ -979,6 +991,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
reference_images: Input.Image,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
@ -1000,6 +1013,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
@ -1031,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
tooltip="Up to 4 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@ -1053,6 +1068,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
reference_video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
reference_images: Input.Image | None = None,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
@ -1085,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
duration=str(duration),
|
||||
image_list=image_list if image_list else None,
|
||||
video_list=video_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
@ -1114,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
tooltip="Up to 4 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@ -1134,6 +1152,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
reference_images: Input.Image | None = None,
|
||||
resolution: str = "1080p",
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
@ -1166,6 +1185,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
duration=None,
|
||||
image_list=image_list if image_list else None,
|
||||
video_list=video_list,
|
||||
mode="pro" if resolution == "1080p" else "std",
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
@ -2159,6 +2179,91 @@ class ImageToVideoWithAudio(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
class MotionControl(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingMotionControl",
|
||||
display_name="Kling Motion Control",
|
||||
category="api node/video/Kling",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True),
|
||||
IO.Image.Input("reference_image"),
|
||||
IO.Video.Input(
|
||||
"reference_video",
|
||||
tooltip="Motion reference video used to drive movement/expression.\n"
|
||||
"Duration limits depend on character_orientation:\n"
|
||||
" - image: 3–10s (max 10s)\n"
|
||||
" - video: 3–30s (max 30s)",
|
||||
),
|
||||
IO.Boolean.Input("keep_original_sound", default=True),
|
||||
IO.Combo.Input(
|
||||
"character_orientation",
|
||||
options=["video", "image"],
|
||||
tooltip="Controls where the character's facing/orientation comes from.\n"
|
||||
"video: movements, expressions, camera moves, and orientation "
|
||||
"follow the motion reference video (other details via prompt).\n"
|
||||
"image: movements and expressions still follow the motion reference video, "
|
||||
"but the character orientation matches the reference image (camera/other details via prompt).",
|
||||
),
|
||||
IO.Combo.Input("mode", options=["pro", "std"]),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
reference_image: Input.Image,
|
||||
reference_video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
character_orientation: str,
|
||||
mode: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=2500)
|
||||
validate_image_dimensions(reference_image, min_width=340, min_height=340)
|
||||
validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1))
|
||||
if character_orientation == "image":
|
||||
validate_video_duration(reference_video, min_duration=3, max_duration=10)
|
||||
else:
|
||||
validate_video_duration(reference_video, min_duration=3, max_duration=30)
|
||||
validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=MotionControlRequest(
|
||||
prompt=prompt,
|
||||
image_url=(await upload_images_to_comfyapi(cls, reference_image))[0],
|
||||
video_url=await upload_video_to_comfyapi(cls, reference_video),
|
||||
keep_original_sound="yes" if keep_original_sound else "no",
|
||||
character_orientation=character_orientation,
|
||||
mode=mode,
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
class KlingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -2184,6 +2289,7 @@ class KlingExtension(ComfyExtension):
|
||||
OmniProImageNode,
|
||||
TextToVideoWithAudio,
|
||||
ImageToVideoWithAudio,
|
||||
MotionControl,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,46 +1,45 @@
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import os
|
||||
from enum import Enum
|
||||
from inspect import cleandoc
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
import folder_paths
|
||||
import base64
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis import (
|
||||
OpenAIImageGenerationRequest,
|
||||
OpenAIImageEditRequest,
|
||||
OpenAIImageGenerationResponse,
|
||||
OpenAICreateResponse,
|
||||
OpenAIResponse,
|
||||
CreateModelResponseProperties,
|
||||
Item,
|
||||
OutputContent,
|
||||
InputImageContent,
|
||||
Detail,
|
||||
InputTextContent,
|
||||
InputMessage,
|
||||
InputMessageContentList,
|
||||
InputContent,
|
||||
InputFileContent,
|
||||
InputImageContent,
|
||||
InputMessage,
|
||||
InputMessageContentList,
|
||||
InputTextContent,
|
||||
Item,
|
||||
OpenAICreateResponse,
|
||||
OpenAIResponse,
|
||||
OutputContent,
|
||||
)
|
||||
from comfy_api_nodes.apis.openai_api import (
|
||||
OpenAIImageEditRequest,
|
||||
OpenAIImageGenerationRequest,
|
||||
OpenAIImageGenerationResponse,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.util import (
|
||||
downscale_image_tensor,
|
||||
download_url_to_bytesio,
|
||||
validate_string,
|
||||
tensor_to_base64_string,
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
download_url_to_bytesio,
|
||||
downscale_image_tensor,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
|
||||
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
|
||||
|
||||
@ -98,9 +97,6 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
|
||||
|
||||
|
||||
class OpenAIDalle2(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -108,7 +104,7 @@ class OpenAIDalle2(IO.ComfyNode):
|
||||
node_id="OpenAIDalle2",
|
||||
display_name="OpenAI DALL·E 2",
|
||||
category="api node/image/OpenAI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -234,9 +230,6 @@ class OpenAIDalle2(IO.ComfyNode):
|
||||
|
||||
|
||||
class OpenAIDalle3(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -244,7 +237,7 @@ class OpenAIDalle3(IO.ComfyNode):
|
||||
node_id="OpenAIDalle3",
|
||||
display_name="OpenAI DALL·E 3",
|
||||
category="api node/image/OpenAI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -326,10 +319,16 @@ class OpenAIDalle3(IO.ComfyNode):
|
||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||
|
||||
|
||||
def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None:
|
||||
# https://platform.openai.com/docs/pricing
|
||||
return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0
|
||||
|
||||
|
||||
def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None:
|
||||
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
|
||||
|
||||
|
||||
class OpenAIGPTImage1(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -337,13 +336,13 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
node_id="OpenAIGPTImage1",
|
||||
display_name="OpenAI GPT Image 1",
|
||||
category="api node/image/OpenAI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
tooltip="Text prompt for GPT Image 1",
|
||||
tooltip="Text prompt for GPT Image",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -365,8 +364,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"background",
|
||||
default="opaque",
|
||||
options=["opaque", "transparent"],
|
||||
default="auto",
|
||||
options=["auto", "opaque", "transparent"],
|
||||
tooltip="Return image with or without background",
|
||||
optional=True,
|
||||
),
|
||||
@ -397,6 +396,11 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["gpt-image-1", "gpt-image-1.5"],
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -412,32 +416,34 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt,
|
||||
seed=0,
|
||||
quality="low",
|
||||
background="opaque",
|
||||
image=None,
|
||||
mask=None,
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
prompt: str,
|
||||
seed: int = 0,
|
||||
quality: str = "low",
|
||||
background: str = "opaque",
|
||||
image: Input.Image | None = None,
|
||||
mask: Input.Image | None = None,
|
||||
n: int = 1,
|
||||
size: str = "1024x1024",
|
||||
model: str = "gpt-image-1",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
model = "gpt-image-1"
|
||||
path = "/proxy/openai/images/generations"
|
||||
content_type = "application/json"
|
||||
request_class = OpenAIImageGenerationRequest
|
||||
files = []
|
||||
|
||||
if mask is not None and image is None:
|
||||
raise ValueError("Cannot use a mask without an input image")
|
||||
|
||||
if model == "gpt-image-1":
|
||||
price_extractor = calculate_tokens_price_image_1
|
||||
elif model == "gpt-image-1.5":
|
||||
price_extractor = calculate_tokens_price_image_1_5
|
||||
else:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
|
||||
if image is not None:
|
||||
path = "/proxy/openai/images/edits"
|
||||
request_class = OpenAIImageEditRequest
|
||||
content_type = "multipart/form-data"
|
||||
|
||||
files = []
|
||||
batch_size = image.shape[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
single_image = image[i : i + 1]
|
||||
scaled_image = downscale_image_tensor(single_image).squeeze()
|
||||
single_image = image[i: i + 1]
|
||||
scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()
|
||||
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
@ -450,44 +456,59 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
else:
|
||||
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||
|
||||
if mask is not None:
|
||||
if image is None:
|
||||
raise Exception("Cannot use a mask without an input image")
|
||||
if image.shape[0] != 1:
|
||||
raise Exception("Cannot use a mask with multiple image")
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
batch, height, width = mask.shape
|
||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
if mask is not None:
|
||||
if image.shape[0] != 1:
|
||||
raise Exception("Cannot use a mask with multiple image")
|
||||
if mask.shape[1:] != image.shape[1:-1]:
|
||||
raise Exception("Mask and Image must be the same size")
|
||||
_, height, width = mask.shape
|
||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||
|
||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
|
||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()
|
||||
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_img_byte_arr = BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||
|
||||
# Build the operation
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=path, method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=request_class(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
seed=seed,
|
||||
size=size,
|
||||
),
|
||||
files=files if files else None,
|
||||
content_type=content_type,
|
||||
)
|
||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_np)
|
||||
mask_img_byte_arr = BytesIO()
|
||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||
mask_img_byte_arr.seek(0)
|
||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=OpenAIImageEditRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
seed=seed,
|
||||
size=size,
|
||||
moderation="low",
|
||||
),
|
||||
content_type="multipart/form-data",
|
||||
files=files,
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
else:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
|
||||
response_model=OpenAIImageGenerationResponse,
|
||||
data=OpenAIImageGenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
background=background,
|
||||
n=n,
|
||||
seed=seed,
|
||||
size=size,
|
||||
moderation="low",
|
||||
),
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||
|
||||
|
||||
|
||||
@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = {
|
||||
"Starlight (Astra) Fast": "slf-1",
|
||||
"Starlight (Astra) Creative": "slc-1",
|
||||
}
|
||||
UPSCALER_VALUES_MAP = {
|
||||
"FullHD (1080p)": 1920,
|
||||
"4K (2160p)": 3840,
|
||||
}
|
||||
|
||||
|
||||
class TopazImageEnhance(IO.ComfyNode):
|
||||
@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
IO.Video.Input("video"),
|
||||
IO.Boolean.Input("upscaler_enabled", default=True),
|
||||
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
||||
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())),
|
||||
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||
IO.Combo.Input(
|
||||
"upscaler_creativity",
|
||||
options=["low", "middle", "high"],
|
||||
@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
target_frame_rate = src_frame_rate
|
||||
filters = []
|
||||
if upscaler_enabled:
|
||||
target_width = UPSCALER_VALUES_MAP[upscaler_resolution]
|
||||
target_height = UPSCALER_VALUES_MAP[upscaler_resolution]
|
||||
if "1080p" in upscaler_resolution:
|
||||
target_pixel_p = 1080
|
||||
max_long_side = 1920
|
||||
else:
|
||||
target_pixel_p = 2160
|
||||
max_long_side = 3840
|
||||
ar = src_width / src_height
|
||||
if src_width >= src_height:
|
||||
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
|
||||
target_height = target_pixel_p
|
||||
target_width = int(target_height * ar)
|
||||
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
|
||||
if target_width > max_long_side:
|
||||
target_width = max_long_side
|
||||
target_height = int(target_width / ar)
|
||||
else:
|
||||
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
|
||||
target_width = target_pixel_p
|
||||
target_height = int(target_width / ar)
|
||||
# Check if height exceeds standard bounds
|
||||
if target_height > max_long_side:
|
||||
target_height = max_long_side
|
||||
target_width = int(target_height * ar)
|
||||
if target_width % 2 != 0:
|
||||
target_width += 1
|
||||
if target_height % 2 != 0:
|
||||
target_height += 1
|
||||
filters.append(
|
||||
topaz_api.VideoEnhancementFilter(
|
||||
model=UPSCALER_MODELS_MAP[upscaler_model],
|
||||
|
||||
@ -155,7 +155,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
model_seed=model_seed,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
face_limit=face_limit,
|
||||
face_limit=face_limit if face_limit != -1 else None,
|
||||
geometry_quality=geometry_quality,
|
||||
auto_size=True,
|
||||
quad=quad,
|
||||
@ -255,7 +255,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
texture_alignment=texture_alignment,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
face_limit=face_limit,
|
||||
face_limit=face_limit if face_limit != -1 else None,
|
||||
auto_size=True,
|
||||
quad=quad,
|
||||
),
|
||||
@ -369,7 +369,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
texture_quality=texture_quality,
|
||||
geometry_quality=geometry_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
face_limit=face_limit,
|
||||
face_limit=face_limit if face_limit != -1 else None,
|
||||
quad=quad,
|
||||
),
|
||||
)
|
||||
|
||||
@ -168,6 +168,8 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
# Only add generateAudio for Veo 3 models
|
||||
if model.find("veo-2.0") == -1:
|
||||
parameters["generateAudio"] = generate_audio
|
||||
# force "enhance_prompt" to True for Veo3 models
|
||||
parameters["enhancePrompt"] = True
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
@ -291,7 +293,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
IO.Boolean.Input(
|
||||
"enhance_prompt",
|
||||
default=True,
|
||||
tooltip="Whether to enhance the prompt with AI assistance",
|
||||
tooltip="This parameter is deprecated and ignored.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
@ -21,26 +19,26 @@ from comfy_api_nodes.util import (
|
||||
|
||||
class Text2ImageInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
negative_prompt: str | None = Field(None)
|
||||
|
||||
|
||||
class Image2ImageInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
negative_prompt: str | None = Field(None)
|
||||
images: list[str] = Field(..., min_length=1, max_length=2)
|
||||
|
||||
|
||||
class Text2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
audio_url: Optional[str] = Field(None)
|
||||
negative_prompt: str | None = Field(None)
|
||||
audio_url: str | None = Field(None)
|
||||
|
||||
|
||||
class Image2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
negative_prompt: str | None = Field(None)
|
||||
img_url: str = Field(...)
|
||||
audio_url: Optional[str] = Field(None)
|
||||
audio_url: str | None = Field(None)
|
||||
|
||||
|
||||
class Txt2ImageParametersField(BaseModel):
|
||||
@ -48,32 +46,34 @@ class Txt2ImageParametersField(BaseModel):
|
||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class Image2ImageParametersField(BaseModel):
|
||||
size: Optional[str] = Field(None)
|
||||
size: str | None = Field(None)
|
||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
watermark: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class Text2VideoParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
duration: int = Field(5, ge=5, le=10)
|
||||
duration: int = Field(5, ge=5, le=15)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(True)
|
||||
audio: bool = Field(False, description="Should be audio generated automatically")
|
||||
watermark: bool = Field(False)
|
||||
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||
shot_type: str = Field("single")
|
||||
|
||||
|
||||
class Image2VideoParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
duration: int = Field(5, ge=5, le=10)
|
||||
duration: int = Field(5, ge=5, le=15)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(True)
|
||||
audio: bool = Field(False, description="Should be audio generated automatically")
|
||||
watermark: bool = Field(False)
|
||||
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||
shot_type: str = Field("single")
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
@ -106,39 +106,39 @@ class TaskCreationOutputField(BaseModel):
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
output: Optional[TaskCreationOutputField] = Field(None)
|
||||
output: TaskCreationOutputField | None = Field(None)
|
||||
request_id: str = Field(...)
|
||||
code: Optional[str] = Field(None, description="The error code of the failed request.")
|
||||
message: Optional[str] = Field(None, description="Details of the failed request.")
|
||||
code: str | None = Field(None, description="Error code for the failed request.")
|
||||
message: str | None = Field(None, description="Details about the failed request.")
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
url: Optional[str] = Field(None)
|
||||
code: Optional[str] = Field(None)
|
||||
message: Optional[str] = Field(None)
|
||||
url: str | None = Field(None)
|
||||
code: str | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
|
||||
|
||||
class ImageTaskStatusOutputField(TaskCreationOutputField):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
results: Optional[list[TaskResult]] = Field(None)
|
||||
results: list[TaskResult] | None = Field(None)
|
||||
|
||||
|
||||
class VideoTaskStatusOutputField(TaskCreationOutputField):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
video_url: Optional[str] = Field(None)
|
||||
code: Optional[str] = Field(None)
|
||||
message: Optional[str] = Field(None)
|
||||
video_url: str | None = Field(None)
|
||||
code: str | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
|
||||
|
||||
class ImageTaskStatusResponse(BaseModel):
|
||||
output: Optional[ImageTaskStatusOutputField] = Field(None)
|
||||
output: ImageTaskStatusOutputField | None = Field(None)
|
||||
request_id: str = Field(...)
|
||||
|
||||
|
||||
class VideoTaskStatusResponse(BaseModel):
|
||||
output: Optional[VideoTaskStatusOutputField] = Field(None)
|
||||
output: VideoTaskStatusOutputField | None = Field(None)
|
||||
request_id: str = Field(...)
|
||||
|
||||
|
||||
@ -152,7 +152,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
node_id="WanTextToImageApi",
|
||||
display_name="Wan Text to Image",
|
||||
category="api node/image/Wan",
|
||||
description="Generates image based on text prompt.",
|
||||
description="Generates an image based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
@ -164,13 +164,13 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
tooltip="Negative prompt describing what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -208,8 +208,8 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
default=False,
|
||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -234,7 +234,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
height: int = 1024,
|
||||
seed: int = 0,
|
||||
prompt_extend: bool = True,
|
||||
watermark: bool = True,
|
||||
watermark: bool = False,
|
||||
):
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
@ -252,7 +252,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
),
|
||||
)
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
@ -272,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
display_name="Wan Image to Image",
|
||||
category="api node/image/Wan",
|
||||
description="Generates an image from one or two input images and a text prompt. "
|
||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
||||
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
@ -282,19 +282,19 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
|
||||
tooltip="Single-image editing or multi-image fusion. Maximum 2 images.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
tooltip="Negative prompt describing what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
# redo this later as an optional combo of recommended resolutions
|
||||
@ -327,8 +327,8 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
default=False,
|
||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -347,17 +347,17 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
# width: int = 1024,
|
||||
# height: int = 1024,
|
||||
seed: int = 0,
|
||||
watermark: bool = True,
|
||||
watermark: bool = False,
|
||||
):
|
||||
n_images = get_number_of_images(image)
|
||||
if n_images not in (1, 2):
|
||||
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
|
||||
raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.")
|
||||
images = []
|
||||
for i in image:
|
||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
||||
@ -376,7 +376,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
),
|
||||
)
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
@ -395,25 +395,25 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
node_id="WanTextToVideoApi",
|
||||
display_name="Wan Text to Video",
|
||||
category="api node/video/Wan",
|
||||
description="Generates video based on text prompt.",
|
||||
description="Generates a video based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["wan2.5-t2v-preview"],
|
||||
default="wan2.5-t2v-preview",
|
||||
options=["wan2.5-t2v-preview", "wan2.6-t2v"],
|
||||
default="wan2.6-t2v",
|
||||
tooltip="Model to use.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
tooltip="Negative prompt describing what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
@ -433,23 +433,23 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
"1080p: 4:3 (1632x1248)",
|
||||
"1080p: 3:4 (1248x1632)",
|
||||
],
|
||||
default="480p: 1:1 (624x624)",
|
||||
default="720p: 1:1 (960x960)",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=10,
|
||||
max=15,
|
||||
step=5,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Available durations: 5 and 10 seconds",
|
||||
tooltip="A 15-second duration is available only for the Wan 2.6 model.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Audio.Input(
|
||||
"audio",
|
||||
optional=True,
|
||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -466,7 +466,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="If there is no audio input, generate audio automatically.",
|
||||
tooltip="If no audio input is provided, generate audio automatically.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_extend",
|
||||
@ -476,8 +476,16 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
default=False,
|
||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"shot_type",
|
||||
options=["single", "multi"],
|
||||
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
||||
"single continuous shot or multiple shots with cuts. "
|
||||
"This parameter takes effect only when prompt_extend is True.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -498,14 +506,19 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
model: str,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
size: str = "480p: 1:1 (624x624)",
|
||||
size: str = "720p: 1:1 (960x960)",
|
||||
duration: int = 5,
|
||||
audio: Optional[Input.Audio] = None,
|
||||
audio: Input.Audio | None = None,
|
||||
seed: int = 0,
|
||||
generate_audio: bool = False,
|
||||
prompt_extend: bool = True,
|
||||
watermark: bool = True,
|
||||
watermark: bool = False,
|
||||
shot_type: str = "single",
|
||||
):
|
||||
if "480p" in size and model == "wan2.6-t2v":
|
||||
raise ValueError("The Wan 2.6 model does not support 480p.")
|
||||
if duration == 15 and model == "wan2.5-t2v-preview":
|
||||
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
|
||||
width, height = RES_IN_PARENS.search(size).groups()
|
||||
audio_url = None
|
||||
if audio is not None:
|
||||
@ -526,11 +539,12 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
shot_type=shot_type,
|
||||
),
|
||||
),
|
||||
)
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
@ -549,12 +563,12 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
node_id="WanImageToVideoApi",
|
||||
display_name="Wan Image to Video",
|
||||
category="api node/video/Wan",
|
||||
description="Generates video based on the first frame and text prompt.",
|
||||
description="Generates a video from the first frame and a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["wan2.5-i2v-preview"],
|
||||
default="wan2.5-i2v-preview",
|
||||
options=["wan2.5-i2v-preview", "wan2.6-i2v"],
|
||||
default="wan2.6-i2v",
|
||||
tooltip="Model to use.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
@ -564,13 +578,13 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid.",
|
||||
tooltip="Negative prompt describing what to avoid.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
@ -580,23 +594,23 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
"720P",
|
||||
"1080P",
|
||||
],
|
||||
default="480P",
|
||||
default="720P",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=10,
|
||||
max=15,
|
||||
step=5,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Available durations: 5 and 10 seconds",
|
||||
tooltip="Duration 15 available only for WAN2.6 model.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Audio.Input(
|
||||
"audio",
|
||||
optional=True,
|
||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -613,7 +627,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="If there is no audio input, generate audio automatically.",
|
||||
tooltip="If no audio input is provided, generate audio automatically.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_extend",
|
||||
@ -623,8 +637,16 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
default=False,
|
||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"shot_type",
|
||||
options=["single", "multi"],
|
||||
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
||||
"single continuous shot or multiple shots with cuts. "
|
||||
"This parameter takes effect only when prompt_extend is True.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -643,19 +665,24 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
resolution: str = "480P",
|
||||
resolution: str = "720P",
|
||||
duration: int = 5,
|
||||
audio: Optional[Input.Audio] = None,
|
||||
audio: Input.Audio | None = None,
|
||||
seed: int = 0,
|
||||
generate_audio: bool = False,
|
||||
prompt_extend: bool = True,
|
||||
watermark: bool = True,
|
||||
watermark: bool = False,
|
||||
shot_type: str = "single",
|
||||
):
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
if "480P" in resolution and model == "wan2.6-i2v":
|
||||
raise ValueError("The Wan 2.6 model does not support 480P.")
|
||||
if duration == 15 and model == "wan2.5-i2v-preview":
|
||||
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
|
||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
||||
audio_url = None
|
||||
if audio is not None:
|
||||
@ -677,11 +704,12 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
shot_type=shot_type,
|
||||
),
|
||||
),
|
||||
)
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
|
||||
@ -1,16 +1,22 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from io import BytesIO
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.model_management import processing_interrupted
|
||||
from comfy_api.latest import IO
|
||||
|
||||
from .common_exceptions import ProcessingInterrupted
|
||||
|
||||
_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
|
||||
_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
|
||||
|
||||
|
||||
def is_processing_interrupted() -> bool:
|
||||
"""Return True if user/runtime requested interruption."""
|
||||
@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int:
|
||||
if isinstance(path_or_object, str):
|
||||
return os.path.getsize(path_or_object)
|
||||
return len(path_or_object.getvalue())
|
||||
|
||||
|
||||
def to_aiohttp_url(url: str) -> URL:
|
||||
"""If `url` appears to be already percent-encoded (contains at least one valid %HH
|
||||
escape and no malformed '%' sequences) and contains no raw whitespace/control
|
||||
characters preserve the original encoding byte-for-byte (important for signed/presigned URLs).
|
||||
Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed."""
|
||||
if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url):
|
||||
# Avoid encoded=True if URL contains raw whitespace/control chars
|
||||
return URL(url)
|
||||
if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url):
|
||||
# Preserve encoding only if it appears pre-encoded AND has no invalid % sequences
|
||||
return URL(url, encoded=True)
|
||||
return URL(url)
|
||||
|
||||
@ -430,9 +430,9 @@ def _display_text(
|
||||
if status:
|
||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||
if price is not None:
|
||||
p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
|
||||
p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
|
||||
if p != "0":
|
||||
display_lines.append(f"Price: ${p}")
|
||||
display_lines.append(f"Price: {p} credits")
|
||||
if text is not None:
|
||||
display_lines.append(text)
|
||||
if display_lines:
|
||||
|
||||
@ -129,7 +129,7 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||
return img_byte_arr
|
||||
|
||||
|
||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(total_pixels)
|
||||
|
||||
@ -19,6 +19,7 @@ from ._helpers import (
|
||||
get_auth_header,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
to_aiohttp_url,
|
||||
)
|
||||
from .client import _diagnose_connectivity
|
||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||
@ -94,7 +95,7 @@ async def download_url_to_bytesio(
|
||||
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
|
||||
req_task = asyncio.create_task(session.get(url, headers=headers))
|
||||
req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
|
||||
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if monitor_task in done and req_task in pending:
|
||||
|
||||
@ -97,6 +97,11 @@ def get_input_info(
|
||||
extra_info = input_info[1]
|
||||
else:
|
||||
extra_info = {}
|
||||
# if input_type is a list, it is a Combo defined in outdated format; convert it.
|
||||
# NOTE: uncomment this when we are confident old format going away won't cause too much trouble.
|
||||
# if isinstance(input_type, list):
|
||||
# extra_info["options"] = input_type
|
||||
# input_type = IO.Combo.io_type
|
||||
return input_type, input_category, extra_info
|
||||
|
||||
class TopologicalSort:
|
||||
@ -202,15 +207,15 @@ class ExecutionList(TopologicalSort):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def cache_link(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
if to_node_id not in self.execution_cache:
|
||||
self.execution_cache[to_node_id] = {}
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||
if not from_node_id in self.execution_cache_listeners:
|
||||
if from_node_id not in self.execution_cache_listeners:
|
||||
self.execution_cache_listeners[from_node_id] = set()
|
||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||
|
||||
def get_cache(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
if to_node_id not in self.execution_cache:
|
||||
return None
|
||||
value = self.execution_cache[to_node_id].get(from_node_id)
|
||||
if value is None:
|
||||
|
||||
291
comfy_execution/jobs.py
Normal file
291
comfy_execution/jobs.py
Normal file
@ -0,0 +1,291 @@
|
||||
"""
|
||||
Job utilities for the /api/jobs endpoint.
|
||||
Provides normalization and helper functions for job status tracking.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from comfy_api.internal import prune_dict
|
||||
|
||||
|
||||
class JobStatus:
|
||||
"""Job status constants."""
|
||||
PENDING = 'pending'
|
||||
IN_PROGRESS = 'in_progress'
|
||||
COMPLETED = 'completed'
|
||||
FAILED = 'failed'
|
||||
|
||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
|
||||
|
||||
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
"""Extract create_time and workflow_id from extra_data.
|
||||
|
||||
Returns:
|
||||
tuple: (create_time, workflow_id)
|
||||
"""
|
||||
create_time = extra_data.get('create_time')
|
||||
extra_pnginfo = extra_data.get('extra_pnginfo', {})
|
||||
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
|
||||
return create_time, workflow_id
|
||||
|
||||
|
||||
def is_previewable(media_type: str, item: dict) -> bool:
|
||||
"""
|
||||
Check if an output item is previewable.
|
||||
Matches frontend logic in ComfyUI_frontend/src/stores/queueStore.ts
|
||||
Maintains backwards compatibility with existing logic.
|
||||
|
||||
Priority:
|
||||
1. media_type is 'images', 'video', or 'audio'
|
||||
2. format field starts with 'video/' or 'audio/'
|
||||
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
|
||||
"""
|
||||
if media_type in PREVIEWABLE_MEDIA_TYPES:
|
||||
return True
|
||||
|
||||
# Check format field (MIME type).
|
||||
# Maintains backwards compatibility with how custom node outputs are handled in the frontend.
|
||||
fmt = item.get('format', '')
|
||||
if fmt and (fmt.startswith('video/') or fmt.startswith('audio/')):
|
||||
return True
|
||||
|
||||
# Check for 3D files by extension
|
||||
filename = item.get('filename', '').lower()
|
||||
if any(filename.endswith(ext) for ext in THREE_D_EXTENSIONS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def normalize_queue_item(item: tuple, status: str) -> dict:
|
||||
"""Convert queue item tuple to unified job dict.
|
||||
|
||||
Expects item with sensitive data already removed (5 elements).
|
||||
"""
|
||||
priority, prompt_id, _, extra_data, _ = item
|
||||
create_time, workflow_id = _extract_job_metadata(extra_data)
|
||||
|
||||
return prune_dict({
|
||||
'id': prompt_id,
|
||||
'status': status,
|
||||
'priority': priority,
|
||||
'create_time': create_time,
|
||||
'outputs_count': 0,
|
||||
'workflow_id': workflow_id,
|
||||
})
|
||||
|
||||
|
||||
def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: bool = False) -> dict:
|
||||
"""Convert history item dict to unified job dict.
|
||||
|
||||
History items have sensitive data already removed (prompt tuple has 5 elements).
|
||||
"""
|
||||
prompt_tuple = history_item['prompt']
|
||||
priority, _, prompt, extra_data, _ = prompt_tuple
|
||||
create_time, workflow_id = _extract_job_metadata(extra_data)
|
||||
|
||||
status_info = history_item.get('status', {})
|
||||
status_str = status_info.get('status_str') if status_info else None
|
||||
if status_str == 'success':
|
||||
status = JobStatus.COMPLETED
|
||||
elif status_str == 'error':
|
||||
status = JobStatus.FAILED
|
||||
else:
|
||||
status = JobStatus.COMPLETED
|
||||
|
||||
outputs = history_item.get('outputs', {})
|
||||
outputs_count, preview_output = get_outputs_summary(outputs)
|
||||
|
||||
execution_error = None
|
||||
execution_start_time = None
|
||||
execution_end_time = None
|
||||
if status_info:
|
||||
messages = status_info.get('messages', [])
|
||||
for entry in messages:
|
||||
if isinstance(entry, (list, tuple)) and len(entry) >= 2:
|
||||
event_name, event_data = entry[0], entry[1]
|
||||
if isinstance(event_data, dict):
|
||||
if event_name == 'execution_start':
|
||||
execution_start_time = event_data.get('timestamp')
|
||||
elif event_name in ('execution_success', 'execution_error', 'execution_interrupted'):
|
||||
execution_end_time = event_data.get('timestamp')
|
||||
if event_name == 'execution_error':
|
||||
execution_error = event_data
|
||||
|
||||
job = prune_dict({
|
||||
'id': prompt_id,
|
||||
'status': status,
|
||||
'priority': priority,
|
||||
'create_time': create_time,
|
||||
'execution_start_time': execution_start_time,
|
||||
'execution_end_time': execution_end_time,
|
||||
'execution_error': execution_error,
|
||||
'outputs_count': outputs_count,
|
||||
'preview_output': preview_output,
|
||||
'workflow_id': workflow_id,
|
||||
})
|
||||
|
||||
if include_outputs:
|
||||
job['outputs'] = outputs
|
||||
job['execution_status'] = status_info
|
||||
job['workflow'] = {
|
||||
'prompt': prompt,
|
||||
'extra_data': extra_data,
|
||||
}
|
||||
|
||||
return job
|
||||
|
||||
|
||||
def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
|
||||
"""
|
||||
Count outputs and find preview in a single pass.
|
||||
Returns (outputs_count, preview_output).
|
||||
|
||||
Preview priority (matching frontend):
|
||||
1. type="output" with previewable media
|
||||
2. Any previewable media
|
||||
"""
|
||||
count = 0
|
||||
preview_output = None
|
||||
fallback_preview = None
|
||||
|
||||
for node_id, node_outputs in outputs.items():
|
||||
if not isinstance(node_outputs, dict):
|
||||
continue
|
||||
for media_type, items in node_outputs.items():
|
||||
# 'animated' is a boolean flag, not actual output items
|
||||
if media_type == 'animated' or not isinstance(items, list):
|
||||
continue
|
||||
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
count += 1
|
||||
|
||||
if preview_output is None and is_previewable(media_type, item):
|
||||
enriched = {
|
||||
**item,
|
||||
'nodeId': node_id,
|
||||
'mediaType': media_type
|
||||
}
|
||||
if item.get('type') == 'output':
|
||||
preview_output = enriched
|
||||
elif fallback_preview is None:
|
||||
fallback_preview = enriched
|
||||
|
||||
return count, preview_output or fallback_preview
|
||||
|
||||
|
||||
def apply_sorting(jobs: list[dict], sort_by: str, sort_order: str) -> list[dict]:
|
||||
"""Sort jobs list by specified field and order."""
|
||||
reverse = (sort_order == 'desc')
|
||||
|
||||
if sort_by == 'execution_duration':
|
||||
def get_sort_key(job):
|
||||
start = job.get('execution_start_time', 0)
|
||||
end = job.get('execution_end_time', 0)
|
||||
return end - start if end and start else 0
|
||||
else:
|
||||
def get_sort_key(job):
|
||||
return job.get('create_time', 0)
|
||||
|
||||
return sorted(jobs, key=get_sort_key, reverse=reverse)
|
||||
|
||||
|
||||
def get_job(prompt_id: str, running: list, queued: list, history: dict) -> Optional[dict]:
|
||||
"""
|
||||
Get a single job by prompt_id from history or queue.
|
||||
|
||||
Args:
|
||||
prompt_id: The prompt ID to look up
|
||||
running: List of currently running queue items
|
||||
queued: List of pending queue items
|
||||
history: Dict of history items keyed by prompt_id
|
||||
|
||||
Returns:
|
||||
Job dict with full details, or None if not found
|
||||
"""
|
||||
if prompt_id in history:
|
||||
return normalize_history_item(prompt_id, history[prompt_id], include_outputs=True)
|
||||
|
||||
for item in running:
|
||||
if item[1] == prompt_id:
|
||||
return normalize_queue_item(item, JobStatus.IN_PROGRESS)
|
||||
|
||||
for item in queued:
|
||||
if item[1] == prompt_id:
|
||||
return normalize_queue_item(item, JobStatus.PENDING)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_all_jobs(
|
||||
running: list,
|
||||
queued: list,
|
||||
history: dict,
|
||||
status_filter: Optional[list[str]] = None,
|
||||
workflow_id: Optional[str] = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
limit: Optional[int] = None,
|
||||
offset: int = 0
|
||||
) -> tuple[list[dict], int]:
|
||||
"""
|
||||
Get all jobs (running, pending, completed) with filtering and sorting.
|
||||
|
||||
Args:
|
||||
running: List of currently running queue items
|
||||
queued: List of pending queue items
|
||||
history: Dict of history items keyed by prompt_id
|
||||
status_filter: List of statuses to include (from JobStatus.ALL)
|
||||
workflow_id: Filter by workflow ID
|
||||
sort_by: Field to sort by ('created_at', 'execution_duration')
|
||||
sort_order: 'asc' or 'desc'
|
||||
limit: Maximum number of items to return
|
||||
offset: Number of items to skip
|
||||
|
||||
Returns:
|
||||
tuple: (jobs_list, total_count)
|
||||
"""
|
||||
jobs = []
|
||||
|
||||
if status_filter is None:
|
||||
status_filter = JobStatus.ALL
|
||||
|
||||
if JobStatus.IN_PROGRESS in status_filter:
|
||||
for item in running:
|
||||
jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))
|
||||
|
||||
if JobStatus.PENDING in status_filter:
|
||||
for item in queued:
|
||||
jobs.append(normalize_queue_item(item, JobStatus.PENDING))
|
||||
|
||||
include_completed = JobStatus.COMPLETED in status_filter
|
||||
include_failed = JobStatus.FAILED in status_filter
|
||||
if include_completed or include_failed:
|
||||
for prompt_id, history_item in history.items():
|
||||
is_failed = history_item.get('status', {}).get('status_str') == 'error'
|
||||
if (is_failed and include_failed) or (not is_failed and include_completed):
|
||||
jobs.append(normalize_history_item(prompt_id, history_item))
|
||||
|
||||
if workflow_id:
|
||||
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
|
||||
|
||||
jobs = apply_sorting(jobs, sort_by, sort_order)
|
||||
|
||||
total_count = len(jobs)
|
||||
|
||||
if offset > 0:
|
||||
jobs = jobs[offset:]
|
||||
if limit is not None:
|
||||
jobs = jobs[:limit]
|
||||
|
||||
return (jobs, total_count)
|
||||
@ -21,14 +21,24 @@ def validate_node_input(
|
||||
"""
|
||||
# If the types are exactly the same, we can return immediately
|
||||
# Use pre-union behaviour: inverse of `__ne__`
|
||||
# NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class.
|
||||
if not received_type != input_type:
|
||||
return True
|
||||
|
||||
# If one of the types is '*', we can return True immediately; this is the 'Any' type.
|
||||
if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type:
|
||||
return True
|
||||
|
||||
# If the received type or input_type is a MatchType, we can return True immediately;
|
||||
# validation for this is handled by the frontend
|
||||
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
|
||||
return True
|
||||
|
||||
# This accounts for some custom nodes that output lists of options as the type;
|
||||
# if we ever want to break them on purpose, this can be removed
|
||||
if isinstance(received_type, list) and input_type == IO.Combo.io_type:
|
||||
return True
|
||||
|
||||
# Not equal, and not strings
|
||||
if not isinstance(received_type, str) or not isinstance(input_type, str):
|
||||
return False
|
||||
@ -37,6 +47,10 @@ def validate_node_input(
|
||||
received_types = set(t.strip() for t in received_type.split(","))
|
||||
input_types = set(t.strip() for t in input_type.split(","))
|
||||
|
||||
# If any of the types is '*', we can return True immediately; this is the 'Any' type.
|
||||
if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types:
|
||||
return True
|
||||
|
||||
if strict:
|
||||
# In strict mode, all received types must be in the input types
|
||||
return received_types.issubset(input_types)
|
||||
|
||||
@ -55,7 +55,8 @@ class APG(io.ComfyNode):
|
||||
def pre_cfg_function(args):
|
||||
nonlocal running_avg, prev_sigma
|
||||
|
||||
if len(args["conds_out"]) == 1: return args["conds_out"]
|
||||
if len(args["conds_out"]) == 1:
|
||||
return args["conds_out"]
|
||||
|
||||
cond = args["conds_out"][0]
|
||||
uncond = args["conds_out"][1]
|
||||
|
||||
@ -9,6 +9,7 @@ import comfy.utils
|
||||
import node_helpers
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import re
|
||||
|
||||
|
||||
class BasicScheduler(io.ComfyNode):
|
||||
@ -671,7 +672,16 @@ class SamplerSEEDS2(io.ComfyNode):
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
|
||||
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
|
||||
],
|
||||
outputs=[io.Sampler.Output()]
|
||||
outputs=[io.Sampler.Output()],
|
||||
description=(
|
||||
"This sampler node can represent multiple samplers:\n\n"
|
||||
"seeds_2\n"
|
||||
"- default setting\n\n"
|
||||
"exp_heun_2_x0\n"
|
||||
"- solver_type=phi_2, r=1.0, eta=0.0\n\n"
|
||||
"exp_heun_2_x0_sde\n"
|
||||
"- solver_type=phi_2, r=1.0, eta=1.0, s_noise=1.0"
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -751,8 +761,12 @@ class SamplerCustom(io.ComfyNode):
|
||||
out = latent.copy()
|
||||
out["samples"] = samples
|
||||
if "x0" in x0_output:
|
||||
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
|
||||
if samples.is_nested:
|
||||
latent_shapes = [x.shape for x in samples.unbind()]
|
||||
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
|
||||
out_denoised = latent.copy()
|
||||
out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
|
||||
out_denoised["samples"] = x0_out
|
||||
else:
|
||||
out_denoised = out
|
||||
return io.NodeOutput(out, out_denoised)
|
||||
@ -939,8 +953,12 @@ class SamplerCustomAdvanced(io.ComfyNode):
|
||||
out = latent.copy()
|
||||
out["samples"] = samples
|
||||
if "x0" in x0_output:
|
||||
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
|
||||
if samples.is_nested:
|
||||
latent_shapes = [x.shape for x in samples.unbind()]
|
||||
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
|
||||
out_denoised = latent.copy()
|
||||
out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
|
||||
out_denoised["samples"] = x0_out
|
||||
else:
|
||||
out_denoised = out
|
||||
return io.NodeOutput(out, out_denoised)
|
||||
@ -996,6 +1014,25 @@ class AddNoise(io.ComfyNode):
|
||||
|
||||
add_noise = execute
|
||||
|
||||
class ManualSigmas(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ManualSigmas",
|
||||
category="_for_testing/custom_sampling",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
||||
],
|
||||
outputs=[io.Sigmas.Output()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, sigmas) -> io.NodeOutput:
|
||||
sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas)
|
||||
sigmas = [float(i) for i in sigmas]
|
||||
sigmas = torch.FloatTensor(sigmas)
|
||||
return io.NodeOutput(sigmas)
|
||||
|
||||
class CustomSamplersExtension(ComfyExtension):
|
||||
@override
|
||||
@ -1035,6 +1072,7 @@ class CustomSamplersExtension(ComfyExtension):
|
||||
DisableNoise,
|
||||
AddNoise,
|
||||
SamplerCustomAdvanced,
|
||||
ManualSigmas,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -667,16 +667,19 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||
|
||||
@classmethod
|
||||
def _process(cls, image, longer_edge):
|
||||
img = tensor_to_pil(image)
|
||||
w, h = img.size
|
||||
if w > h:
|
||||
new_w = longer_edge
|
||||
new_h = int(h * (longer_edge / w))
|
||||
else:
|
||||
new_h = longer_edge
|
||||
new_w = int(w * (longer_edge / h))
|
||||
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||
return pil_to_tensor(img)
|
||||
resized_images = []
|
||||
for image_i in image:
|
||||
img = tensor_to_pil(image_i)
|
||||
w, h = img.size
|
||||
if w > h:
|
||||
new_w = longer_edge
|
||||
new_h = int(h * (longer_edge / w))
|
||||
else:
|
||||
new_h = longer_edge
|
||||
new_w = int(w * (longer_edge / h))
|
||||
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||
resized_images.append(pil_to_tensor(img))
|
||||
return torch.cat(resized_images, dim=0)
|
||||
|
||||
|
||||
class CenterCropImagesNode(ImageProcessingNode):
|
||||
@ -1125,6 +1128,99 @@ class MergeTextListsNode(TextProcessingNode):
|
||||
# ========== Training Dataset Nodes ==========
|
||||
|
||||
|
||||
class ResolutionBucket(io.ComfyNode):
|
||||
"""Bucket latents and conditions by resolution for efficient batch training."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ResolutionBucket",
|
||||
display_name="Resolution Bucket",
|
||||
category="dataset",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Latent.Input(
|
||||
"latents",
|
||||
tooltip="List of latent dicts to bucket by resolution.",
|
||||
),
|
||||
io.Conditioning.Input(
|
||||
"conditioning",
|
||||
tooltip="List of conditioning lists (must match latents length).",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(
|
||||
display_name="latents",
|
||||
is_output_list=True,
|
||||
tooltip="List of batched latent dicts, one per resolution bucket.",
|
||||
),
|
||||
io.Conditioning.Output(
|
||||
display_name="conditioning",
|
||||
is_output_list=True,
|
||||
tooltip="List of condition lists, one per resolution bucket.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, latents, conditioning):
|
||||
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
|
||||
# conditioning: list[list[cond]]
|
||||
|
||||
# Validate lengths match
|
||||
if len(latents) != len(conditioning):
|
||||
raise ValueError(
|
||||
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
|
||||
)
|
||||
|
||||
# Flatten latents and conditions to individual samples
|
||||
flat_latents = [] # list of (C, H, W) tensors
|
||||
flat_conditions = [] # list of condition lists
|
||||
|
||||
for latent_dict, cond in zip(latents, conditioning):
|
||||
samples = latent_dict["samples"] # (B, C, H, W)
|
||||
batch_size = samples.shape[0]
|
||||
|
||||
# cond is a list of conditions with length == batch_size
|
||||
for i in range(batch_size):
|
||||
flat_latents.append(samples[i]) # (C, H, W)
|
||||
flat_conditions.append(cond[i]) # single condition
|
||||
|
||||
# Group by resolution (H, W)
|
||||
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
|
||||
|
||||
for latent, cond in zip(flat_latents, flat_conditions):
|
||||
# latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W)
|
||||
h, w = latent.shape[-2], latent.shape[-1]
|
||||
key = (h, w)
|
||||
|
||||
if key not in buckets:
|
||||
buckets[key] = {"latents": [], "conditions": []}
|
||||
|
||||
buckets[key]["latents"].append(latent)
|
||||
buckets[key]["conditions"].append(cond)
|
||||
|
||||
# Convert buckets to output format
|
||||
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
|
||||
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
|
||||
|
||||
for (h, w), bucket_data in buckets.items():
|
||||
# Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
|
||||
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
|
||||
output_latents.append({"samples": stacked_latents})
|
||||
|
||||
# Conditions stay as list of condition lists
|
||||
output_conditions.append(bucket_data["conditions"])
|
||||
|
||||
logging.info(
|
||||
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
|
||||
)
|
||||
|
||||
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
|
||||
return io.NodeOutput(output_latents, output_conditions)
|
||||
|
||||
|
||||
class MakeTrainingDataset(io.ComfyNode):
|
||||
"""Encode images with VAE and texts with CLIP to create a training dataset."""
|
||||
|
||||
@ -1373,7 +1469,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
shard_path = os.path.join(dataset_dir, shard_file)
|
||||
|
||||
with open(shard_path, "rb") as f:
|
||||
shard_data = torch.load(f, weights_only=True)
|
||||
shard_data = torch.load(f)
|
||||
|
||||
all_latents.extend(shard_data["latents"])
|
||||
all_conditioning.extend(shard_data["conditioning"])
|
||||
@ -1425,6 +1521,7 @@ class DatasetExtension(ComfyExtension):
|
||||
MakeTrainingDataset,
|
||||
SaveTrainingDataset,
|
||||
LoadTrainingDataset,
|
||||
ResolutionBucket,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -2,280 +2,231 @@ from __future__ import annotations
|
||||
|
||||
import nodes
|
||||
import folder_paths
|
||||
from comfy.cli_args import args
|
||||
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
from inspect import cleandoc
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from comfy.comfy_types import FileLocator, IO
|
||||
from server import PromptServer
|
||||
from comfy_api.latest import ComfyExtension, IO, UI
|
||||
from typing_extensions import override
|
||||
|
||||
SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
class ImageCrop:
|
||||
class ImageCrop(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "crop"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCrop",
|
||||
display_name="Image Crop",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def crop(self, image, width, height, x, y):
|
||||
@classmethod
|
||||
def execute(cls, image, width, height, x, y) -> IO.NodeOutput:
|
||||
x = min(x, image.shape[2] - 1)
|
||||
y = min(y, image.shape[1] - 1)
|
||||
to_x = width + x
|
||||
to_y = height + y
|
||||
img = image[:,y:to_y, x:to_x, :]
|
||||
return (img,)
|
||||
return IO.NodeOutput(img)
|
||||
|
||||
class RepeatImageBatch:
|
||||
crop = execute # TODO: remove
|
||||
|
||||
|
||||
class RepeatImageBatch(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "repeat"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RepeatImageBatch",
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("amount", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def repeat(self, image, amount):
|
||||
@classmethod
|
||||
def execute(cls, image, amount) -> IO.NodeOutput:
|
||||
s = image.repeat((amount, 1,1,1))
|
||||
return (s,)
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
class ImageFromBatch:
|
||||
repeat = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageFromBatch(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
|
||||
"length": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "frombatch"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageFromBatch",
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("batch_index", default=0, min=0, max=4095),
|
||||
IO.Int.Input("length", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def frombatch(self, image, batch_index, length):
|
||||
@classmethod
|
||||
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
|
||||
s_in = image
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s = s_in[batch_index:batch_index + length].clone()
|
||||
return (s,)
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
frombatch = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageAddNoise:
|
||||
class ImageAddNoise(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
|
||||
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "repeat"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageAddNoise",
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
def repeat(self, image, seed, strength):
|
||||
@classmethod
|
||||
def execute(cls, image, seed, strength) -> IO.NodeOutput:
|
||||
generator = torch.manual_seed(seed)
|
||||
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
|
||||
return (s,)
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
class SaveAnimatedWEBP:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
repeat = execute # TODO: remove
|
||||
|
||||
methods = {"default": 4, "fastest": 0, "slowest": 6}
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
||||
"lossless": ("BOOLEAN", {"default": True}),
|
||||
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
|
||||
"method": (list(s.methods.keys()),),
|
||||
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image/animation"
|
||||
|
||||
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
|
||||
method = self.methods.get(method)
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results: list[FileLocator] = []
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pil_images.append(img)
|
||||
|
||||
metadata = pil_images[0].getexif()
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
|
||||
if extra_pnginfo is not None:
|
||||
inital_exif = 0x010f
|
||||
for x in extra_pnginfo:
|
||||
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
|
||||
inital_exif -= 1
|
||||
|
||||
if num_frames == 0:
|
||||
num_frames = len(pil_images)
|
||||
|
||||
c = len(pil_images)
|
||||
for i in range(0, c, num_frames):
|
||||
file = f"{filename}_{counter:05}_.webp"
|
||||
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
|
||||
animated = num_frames != 1
|
||||
return { "ui": { "images": results, "animated": (animated,) } }
|
||||
|
||||
class SaveAnimatedPNG:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
class SaveAnimatedWEBP(IO.ComfyNode):
|
||||
COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6}
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
||||
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAnimatedWEBP",
|
||||
category="image/animation",
|
||||
inputs=[
|
||||
IO.Image.Input("images"),
|
||||
IO.String.Input("filename_prefix", default="ComfyUI"),
|
||||
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
|
||||
IO.Boolean.Input("lossless", default=True),
|
||||
IO.Int.Input("quality", default=80, min=0, max=100),
|
||||
IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())),
|
||||
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
@classmethod
|
||||
def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput:
|
||||
return IO.NodeOutput(
|
||||
ui=UI.ImageSaveHelper.get_save_animated_webp_ui(
|
||||
images=images,
|
||||
filename_prefix=filename_prefix,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
lossless=lossless,
|
||||
quality=quality,
|
||||
method=cls.COMPRESS_METHODS.get(method)
|
||||
)
|
||||
)
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image/animation"
|
||||
|
||||
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pil_images.append(img)
|
||||
|
||||
metadata = None
|
||||
if not args.disable_metadata:
|
||||
metadata = PngInfo()
|
||||
if prompt is not None:
|
||||
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
|
||||
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
|
||||
return { "ui": { "images": results, "animated": (True,)} }
|
||||
|
||||
class SVG:
|
||||
"""
|
||||
Stores SVG representations via a list of BytesIO objects.
|
||||
"""
|
||||
def __init__(self, data: list[BytesIO]):
|
||||
self.data = data
|
||||
|
||||
def combine(self, other: 'SVG') -> 'SVG':
|
||||
return SVG(self.data + other.data)
|
||||
|
||||
@staticmethod
|
||||
def combine_all(svgs: list['SVG']) -> 'SVG':
|
||||
all_svgs_list: list[BytesIO] = []
|
||||
for svg_item in svgs:
|
||||
all_svgs_list.extend(svg_item.data)
|
||||
return SVG(all_svgs_list)
|
||||
save_images = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageStitch:
|
||||
class SaveAnimatedPNG(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAnimatedPNG",
|
||||
category="image/animation",
|
||||
inputs=[
|
||||
IO.Image.Input("images"),
|
||||
IO.String.Input("filename_prefix", default="ComfyUI"),
|
||||
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
|
||||
IO.Int.Input("compress_level", default=4, min=0, max=9),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput:
|
||||
return IO.NodeOutput(
|
||||
ui=UI.ImageSaveHelper.get_save_animated_png_ui(
|
||||
images=images,
|
||||
filename_prefix=filename_prefix,
|
||||
cls=cls,
|
||||
fps=fps,
|
||||
compress_level=compress_level,
|
||||
)
|
||||
)
|
||||
|
||||
save_images = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageStitch(IO.ComfyNode):
|
||||
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",),
|
||||
"direction": (["right", "down", "left", "up"], {"default": "right"}),
|
||||
"match_image_size": ("BOOLEAN", {"default": True}),
|
||||
"spacing_width": (
|
||||
"INT",
|
||||
{"default": 0, "min": 0, "max": 1024, "step": 2},
|
||||
),
|
||||
"spacing_color": (
|
||||
["white", "black", "red", "green", "blue"],
|
||||
{"default": "white"},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"image2": ("IMAGE",),
|
||||
},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageStitch",
|
||||
display_name="Image Stitch",
|
||||
description="Stitches image2 to image1 in the specified direction.\n"
|
||||
"If image2 is not provided, returns image1 unchanged.\n"
|
||||
"Optional spacing can be added between images.",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image1"),
|
||||
IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"),
|
||||
IO.Boolean.Input("match_image_size", default=True),
|
||||
IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2),
|
||||
IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"),
|
||||
IO.Image.Input("image2", optional=True),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stitch"
|
||||
CATEGORY = "image/transform"
|
||||
DESCRIPTION = """
|
||||
Stitches image2 to image1 in the specified direction.
|
||||
If image2 is not provided, returns image1 unchanged.
|
||||
Optional spacing can be added between images.
|
||||
"""
|
||||
|
||||
def stitch(
|
||||
self,
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
image1,
|
||||
direction,
|
||||
match_image_size,
|
||||
spacing_width,
|
||||
spacing_color,
|
||||
image2=None,
|
||||
):
|
||||
) -> IO.NodeOutput:
|
||||
if image2 is None:
|
||||
return (image1,)
|
||||
return IO.NodeOutput(image1)
|
||||
|
||||
# Handle batch size differences
|
||||
if image1.shape[0] != image2.shape[0]:
|
||||
@ -412,36 +363,30 @@ Optional spacing can be added between images.
|
||||
images.insert(1, spacing)
|
||||
|
||||
concat_dim = 2 if direction in ["left", "right"] else 1
|
||||
return (torch.cat(images, dim=concat_dim),)
|
||||
return IO.NodeOutput(torch.cat(images, dim=concat_dim))
|
||||
|
||||
stitch = execute # TODO: remove
|
||||
|
||||
|
||||
class ResizeAndPadImage(IO.ComfyNode):
|
||||
|
||||
class ResizeAndPadImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"target_width": ("INT", {
|
||||
"default": 512,
|
||||
"min": 1,
|
||||
"max": MAX_RESOLUTION,
|
||||
"step": 1
|
||||
}),
|
||||
"target_height": ("INT", {
|
||||
"default": 512,
|
||||
"min": 1,
|
||||
"max": MAX_RESOLUTION,
|
||||
"step": 1
|
||||
}),
|
||||
"padding_color": (["white", "black"],),
|
||||
"interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ResizeAndPadImage",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Combo.Input("padding_color", options=["white", "black"]),
|
||||
IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "resize_and_pad"
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
|
||||
@classmethod
|
||||
def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput:
|
||||
batch_size, orig_height, orig_width, channels = image.shape
|
||||
|
||||
scale_w = target_width / orig_width
|
||||
@ -469,52 +414,47 @@ class ResizeAndPadImage:
|
||||
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
|
||||
|
||||
output = padded.permute(0, 2, 3, 1)
|
||||
return (output,)
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class SaveSVGNode:
|
||||
"""
|
||||
Save SVG files on disk.
|
||||
"""
|
||||
resize_and_pad = execute # TODO: remove
|
||||
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
RETURN_TYPES = ()
|
||||
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
|
||||
FUNCTION = "save_svg"
|
||||
CATEGORY = "image/save" # Changed
|
||||
OUTPUT_NODE = True
|
||||
class SaveSVGNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"svg": ("SVG",), # Changed
|
||||
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO"
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveSVGNode",
|
||||
description="Save SVG files on disk.",
|
||||
category="image/save",
|
||||
inputs=[
|
||||
IO.SVG.Input("svg"),
|
||||
IO.String.Input(
|
||||
"filename_prefix",
|
||||
default="svg/ComfyUI",
|
||||
tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
|
||||
),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results = list()
|
||||
@classmethod
|
||||
def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results: list[UI.SavedResult] = []
|
||||
|
||||
# Prepare metadata JSON
|
||||
metadata_dict = {}
|
||||
if prompt is not None:
|
||||
metadata_dict["prompt"] = prompt
|
||||
if extra_pnginfo is not None:
|
||||
metadata_dict.update(extra_pnginfo)
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata_dict["prompt"] = cls.hidden.prompt
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
metadata_dict.update(cls.hidden.extra_pnginfo)
|
||||
|
||||
# Convert metadata to JSON string
|
||||
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
||||
|
||||
|
||||
for batch_number, svg_bytes in enumerate(svg.data):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
||||
@ -544,57 +484,64 @@ class SaveSVGNode:
|
||||
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
||||
svg_file.write(svg_content.encode('utf-8'))
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output))
|
||||
counter += 1
|
||||
return { "ui": { "images": results } }
|
||||
return IO.NodeOutput(ui={"images": results})
|
||||
|
||||
class GetImageSize:
|
||||
save_svg = execute # TODO: remove
|
||||
|
||||
|
||||
class GetImageSize(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": (IO.IMAGE,),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GetImageSize",
|
||||
display_name="Get Image Size",
|
||||
description="Returns width and height of the image, and passes it through unchanged.",
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Int.Output(display_name="width"),
|
||||
IO.Int.Output(display_name="height"),
|
||||
IO.Int.Output(display_name="batch_size"),
|
||||
],
|
||||
hidden=[IO.Hidden.unique_id],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
|
||||
RETURN_NAMES = ("width", "height", "batch_size")
|
||||
FUNCTION = "get_size"
|
||||
|
||||
CATEGORY = "image"
|
||||
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
|
||||
|
||||
def get_size(self, image, unique_id=None) -> tuple[int, int]:
|
||||
@classmethod
|
||||
def execute(cls, image) -> IO.NodeOutput:
|
||||
height = image.shape[1]
|
||||
width = image.shape[2]
|
||||
batch_size = image.shape[0]
|
||||
|
||||
# Send progress text to display size on the node
|
||||
if unique_id:
|
||||
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
|
||||
|
||||
return width, height, batch_size
|
||||
return IO.NodeOutput(width, height, batch_size)
|
||||
|
||||
get_size = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageRotate(IO.ComfyNode):
|
||||
|
||||
class ImageRotate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": (IO.IMAGE,),
|
||||
"rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
|
||||
}}
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "rotate"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageRotate",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def rotate(self, image, rotation):
|
||||
@classmethod
|
||||
def execute(cls, image, rotation) -> IO.NodeOutput:
|
||||
rotate_by = 0
|
||||
if rotation.startswith("90"):
|
||||
rotate_by = 1
|
||||
@ -604,41 +551,57 @@ class ImageRotate:
|
||||
rotate_by = 3
|
||||
|
||||
image = torch.rot90(image, k=rotate_by, dims=[2, 1])
|
||||
return (image,)
|
||||
return IO.NodeOutput(image)
|
||||
|
||||
rotate = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageFlip(IO.ComfyNode):
|
||||
|
||||
class ImageFlip:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": (IO.IMAGE,),
|
||||
"flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
|
||||
}}
|
||||
RETURN_TYPES = (IO.IMAGE,)
|
||||
FUNCTION = "flip"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageFlip",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def flip(self, image, flip_method):
|
||||
@classmethod
|
||||
def execute(cls, image, flip_method) -> IO.NodeOutput:
|
||||
if flip_method.startswith("x"):
|
||||
image = torch.flip(image, dims=[1])
|
||||
elif flip_method.startswith("y"):
|
||||
image = torch.flip(image, dims=[2])
|
||||
|
||||
return (image,)
|
||||
return IO.NodeOutput(image)
|
||||
|
||||
class ImageScaleToMaxDimension:
|
||||
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
|
||||
flip = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageScaleToMaxDimension(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"image": ("IMAGE",),
|
||||
"upscale_method": (s.upscale_methods,),
|
||||
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "upscale"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageScaleToMaxDimension",
|
||||
category="image/upscaling",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input(
|
||||
"upscale_method",
|
||||
options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"],
|
||||
),
|
||||
IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image/upscaling"
|
||||
|
||||
def upscale(self, image, upscale_method, largest_size):
|
||||
@classmethod
|
||||
def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
|
||||
height = image.shape[1]
|
||||
width = image.shape[2]
|
||||
|
||||
@ -655,20 +618,30 @@ class ImageScaleToMaxDimension:
|
||||
samples = image.movedim(-1, 1)
|
||||
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return (s,)
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
"ImageFromBatch": ImageFromBatch,
|
||||
"ImageAddNoise": ImageAddNoise,
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
"SaveSVGNode": SaveSVGNode,
|
||||
"ImageStitch": ImageStitch,
|
||||
"ResizeAndPadImage": ResizeAndPadImage,
|
||||
"GetImageSize": GetImageSize,
|
||||
"ImageRotate": ImageRotate,
|
||||
"ImageFlip": ImageFlip,
|
||||
"ImageScaleToMaxDimension": ImageScaleToMaxDimension,
|
||||
}
|
||||
upscale = execute # TODO: remove
|
||||
|
||||
|
||||
class ImagesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ImageCrop,
|
||||
RepeatImageBatch,
|
||||
ImageFromBatch,
|
||||
ImageAddNoise,
|
||||
SaveAnimatedWEBP,
|
||||
SaveAnimatedPNG,
|
||||
SaveSVGNode,
|
||||
ImageStitch,
|
||||
ResizeAndPadImage,
|
||||
GetImageSize,
|
||||
ImageRotate,
|
||||
ImageFlip,
|
||||
ImageScaleToMaxDimension,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ImagesExtension:
|
||||
return ImagesExtension()
|
||||
|
||||
@ -5,6 +5,7 @@ import nodes
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import logging
|
||||
import math
|
||||
|
||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
@ -207,12 +208,54 @@ class LatentCut(io.ComfyNode):
|
||||
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentCutToBatch(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentCutToBatch",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Combo.Input("dim", options=["t", "x", "y"]),
|
||||
io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
|
||||
samples_out = samples.copy()
|
||||
|
||||
s1 = samples["samples"]
|
||||
|
||||
if "x" in dim:
|
||||
dim = s1.ndim - 1
|
||||
elif "y" in dim:
|
||||
dim = s1.ndim - 2
|
||||
elif "t" in dim:
|
||||
dim = s1.ndim - 3
|
||||
|
||||
if dim < 2:
|
||||
return io.NodeOutput(samples)
|
||||
|
||||
s = s1.movedim(dim, 1)
|
||||
if s.shape[1] < slice_size:
|
||||
slice_size = s.shape[1]
|
||||
elif s.shape[1] % slice_size != 0:
|
||||
s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
|
||||
new_shape = [-1, slice_size] + list(s.shape[2:])
|
||||
samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
class LatentBatch(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentBatch",
|
||||
category="latent/batch",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
@ -435,6 +478,7 @@ class LatentExtension(ComfyExtension):
|
||||
LatentInterpolate,
|
||||
LatentConcat,
|
||||
LatentCut,
|
||||
LatentCutToBatch,
|
||||
LatentBatch,
|
||||
LatentBatchSeedBehavior,
|
||||
LatentApplyOperation,
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import TypedDict
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.latest import _io
|
||||
|
||||
# sentinel for missing inputs
|
||||
MISSING = object()
|
||||
|
||||
|
||||
class SwitchNode(io.ComfyNode):
|
||||
@ -14,6 +17,37 @@ class SwitchNode(io.ComfyNode):
|
||||
display_name="Switch",
|
||||
category="logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
io.MatchType.Input("on_false", template=template, lazy=True),
|
||||
io.MatchType.Input("on_true", template=template, lazy=True),
|
||||
],
|
||||
outputs=[
|
||||
io.MatchType.Output(template=template, display_name="output"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def check_lazy_status(cls, switch, on_false=None, on_true=None):
|
||||
if switch and on_true is None:
|
||||
return ["on_true"]
|
||||
if not switch and on_false is None:
|
||||
return ["on_false"]
|
||||
|
||||
@classmethod
|
||||
def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
|
||||
return io.NodeOutput(on_true if switch else on_false)
|
||||
|
||||
|
||||
class SoftSwitchNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
template = io.MatchType.Template("switch")
|
||||
return io.Schema(
|
||||
node_id="ComfySoftSwitchNode",
|
||||
display_name="Soft Switch",
|
||||
category="logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
|
||||
@ -25,14 +59,14 @@ class SwitchNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def check_lazy_status(cls, switch, on_false=..., on_true=...):
|
||||
# We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
|
||||
def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
|
||||
# We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs.
|
||||
# This trick allows us to ignore the value of the switch and still be able to run execute().
|
||||
|
||||
# One of the inputs may be missing, in which case we need to evaluate the other input
|
||||
if on_false is ...:
|
||||
if on_false is MISSING:
|
||||
return ["on_true"]
|
||||
if on_true is ...:
|
||||
if on_true is MISSING:
|
||||
return ["on_false"]
|
||||
# Normal lazy switch operation
|
||||
if switch and on_true is None:
|
||||
@ -41,22 +75,50 @@ class SwitchNode(io.ComfyNode):
|
||||
return ["on_false"]
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, switch, on_false=..., on_true=...):
|
||||
def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING):
|
||||
# This check happens before check_lazy_status(), so we can eliminate the case where
|
||||
# both inputs are missing.
|
||||
if on_false is ... and on_true is ...:
|
||||
if on_false is MISSING and on_true is MISSING:
|
||||
return "At least one of on_false or on_true must be connected to Switch node"
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
|
||||
if on_true is ...:
|
||||
def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
|
||||
if on_true is MISSING:
|
||||
return io.NodeOutput(on_false)
|
||||
if on_false is ...:
|
||||
if on_false is MISSING:
|
||||
return io.NodeOutput(on_true)
|
||||
return io.NodeOutput(on_true if switch else on_false)
|
||||
|
||||
|
||||
class CustomComboNode(io.ComfyNode):
|
||||
"""
|
||||
Frontend node that allows user to write their own options for a combo.
|
||||
This is here to make sure the node has a backend-representation to avoid some annoyances.
|
||||
"""
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CustomCombo",
|
||||
display_name="Custom Combo",
|
||||
category="utils",
|
||||
is_experimental=True,
|
||||
inputs=[io.Combo.Input("choice", options=[])],
|
||||
outputs=[io.String.Output()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, choice: io.Combo.Type) -> bool:
|
||||
# NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
|
||||
# I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
|
||||
# I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(choice)
|
||||
|
||||
|
||||
class DCTestNode(io.ComfyNode):
|
||||
class DCValues(TypedDict):
|
||||
combo: str
|
||||
@ -72,14 +134,14 @@ class DCTestNode(io.ComfyNode):
|
||||
display_name="DCTest",
|
||||
category="logic",
|
||||
is_output_node=True,
|
||||
inputs=[_io.DynamicCombo.Input("combo", options=[
|
||||
_io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||
_io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
|
||||
_io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
||||
_io.DynamicCombo.Option("option4", [
|
||||
_io.DynamicCombo.Input("subcombo", options=[
|
||||
_io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
|
||||
_io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
|
||||
inputs=[io.DynamicCombo.Input("combo", options=[
|
||||
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||
io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
|
||||
io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
||||
io.DynamicCombo.Option("option4", [
|
||||
io.DynamicCombo.Input("subcombo", options=[
|
||||
io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
|
||||
io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
|
||||
])
|
||||
])]
|
||||
)],
|
||||
@ -141,14 +203,65 @@ class AutogrowPrefixTestNode(io.ComfyNode):
|
||||
combined = ",".join([str(x) for x in vals])
|
||||
return io.NodeOutput(combined)
|
||||
|
||||
class ComboOutputTestNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ComboOptionTestNode",
|
||||
display_name="ComboOptionTest",
|
||||
category="logic",
|
||||
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
|
||||
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
|
||||
outputs=[io.Combo.Output(), io.Combo.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(combo, combo2)
|
||||
|
||||
class ConvertStringToComboNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ConvertStringToComboNode",
|
||||
display_name="Convert String to Combo",
|
||||
category="logic",
|
||||
inputs=[io.String.Input("string")],
|
||||
outputs=[io.Combo.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, string: str) -> io.NodeOutput:
|
||||
return io.NodeOutput(string)
|
||||
|
||||
class InvertBooleanNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="InvertBooleanNode",
|
||||
display_name="Invert Boolean",
|
||||
category="logic",
|
||||
inputs=[io.Boolean.Input("boolean")],
|
||||
outputs=[io.Boolean.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, boolean: bool) -> io.NodeOutput:
|
||||
return io.NodeOutput(not boolean)
|
||||
|
||||
class LogicExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
# SwitchNode,
|
||||
SwitchNode,
|
||||
CustomComboNode,
|
||||
# SoftSwitchNode,
|
||||
# ConvertStringToComboNode,
|
||||
# DCTestNode,
|
||||
# AutogrowNamesTestNode,
|
||||
# AutogrowPrefixTestNode,
|
||||
# ComboOutputTestNode,
|
||||
# InvertBooleanNode,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> LogicExtension:
|
||||
|
||||
@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Mahiro",
|
||||
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
|
||||
display_name="Mahiro CFG",
|
||||
category="_for_testing",
|
||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||
inputs=[
|
||||
|
||||
@ -348,7 +348,7 @@ class ZImageControlPatch:
|
||||
if self.mask is None:
|
||||
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
|
||||
else:
|
||||
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
|
||||
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
|
||||
|
||||
if latent_image is None:
|
||||
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
|
||||
|
||||
@ -4,11 +4,15 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import TypedDict, Literal
|
||||
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy_extras.nodes_latent import reshape_latent_to
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class Blend(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -221,6 +225,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
||||
io.Image.Input("image"),
|
||||
io.Combo.Input("upscale_method", options=cls.upscale_methods),
|
||||
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
||||
io.Int.Input("resolution_steps", default=1, min=1, max=256),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
@ -228,18 +233,365 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
|
||||
def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
|
||||
samples = image.movedim(-1,1)
|
||||
total = int(megapixels * 1024 * 1024)
|
||||
total = megapixels * 1024 * 1024
|
||||
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
|
||||
height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps
|
||||
|
||||
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
||||
s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
|
||||
s = s.movedim(1,-1)
|
||||
return io.NodeOutput(s)
|
||||
|
||||
class ResizeType(str, Enum):
|
||||
SCALE_BY = "scale by multiplier"
|
||||
SCALE_DIMENSIONS = "scale dimensions"
|
||||
SCALE_LONGER_DIMENSION = "scale longer dimension"
|
||||
SCALE_SHORTER_DIMENSION = "scale shorter dimension"
|
||||
SCALE_WIDTH = "scale width"
|
||||
SCALE_HEIGHT = "scale height"
|
||||
SCALE_TOTAL_PIXELS = "scale total pixels"
|
||||
MATCH_SIZE = "match size"
|
||||
|
||||
def is_image(input: torch.Tensor) -> bool:
|
||||
# images have 4 dimensions: [batch, height, width, channels]
|
||||
# masks have 3 dimensions: [batch, height, width]
|
||||
return len(input.shape) == 4
|
||||
|
||||
def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
|
||||
if is_type_image:
|
||||
input = input.movedim(-1, 1)
|
||||
else:
|
||||
input = input.unsqueeze(1)
|
||||
return input
|
||||
|
||||
def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
|
||||
if is_type_image:
|
||||
input = input.movedim(1, -1)
|
||||
else:
|
||||
input = input.squeeze(1)
|
||||
return input
|
||||
|
||||
def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor:
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
width = round(input.shape[-1] * multiplier)
|
||||
height = round(input.shape[-2] * multiplier)
|
||||
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor:
|
||||
if width == 0 and height == 0:
|
||||
return input
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
|
||||
if width == 0:
|
||||
width = max(1, round(input.shape[-1] * height / input.shape[-2]))
|
||||
elif height == 0:
|
||||
height = max(1, round(input.shape[-2] * width / input.shape[-1]))
|
||||
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor:
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
width = input.shape[-1]
|
||||
height = input.shape[-2]
|
||||
|
||||
if height > width:
|
||||
width = round((width / height) * longer_size)
|
||||
height = longer_size
|
||||
elif width > height:
|
||||
height = round((height / width) * longer_size)
|
||||
width = longer_size
|
||||
else:
|
||||
height = longer_size
|
||||
width = longer_size
|
||||
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor:
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
width = input.shape[-1]
|
||||
height = input.shape[-2]
|
||||
|
||||
if height < width:
|
||||
width = round((width / height) * shorter_size)
|
||||
height = shorter_size
|
||||
elif width > height:
|
||||
height = round((height / width) * shorter_size)
|
||||
width = shorter_size
|
||||
else:
|
||||
height = shorter_size
|
||||
width = shorter_size
|
||||
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor:
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
total = int(megapixels * 1024 * 1024)
|
||||
|
||||
scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2]))
|
||||
width = round(input.shape[-1] * scale_by)
|
||||
height = round(input.shape[-2] * scale_by)
|
||||
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor:
|
||||
is_type_image = is_image(input)
|
||||
input = init_image_mask_input(input, is_type_image)
|
||||
match = init_image_mask_input(match, is_image(match))
|
||||
|
||||
width = match.shape[-1]
|
||||
height = match.shape[-2]
|
||||
input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
|
||||
input = finalize_image_mask_input(input, is_type_image)
|
||||
return input
|
||||
|
||||
class ResizeImageMaskNode(io.ComfyNode):
|
||||
|
||||
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||
crop_methods = ["disabled", "center"]
|
||||
|
||||
class ResizeTypedDict(TypedDict):
|
||||
resize_type: ResizeType
|
||||
scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||
crop: Literal["disabled", "center"]
|
||||
multiplier: float
|
||||
width: int
|
||||
height: int
|
||||
longer_size: int
|
||||
shorter_size: int
|
||||
megapixels: float
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
template = io.MatchType.Template("input_type", [io.Image, io.Mask])
|
||||
crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center")
|
||||
return io.Schema(
|
||||
node_id="ResizeImageMaskNode",
|
||||
display_name="Resize Image/Mask",
|
||||
category="transform",
|
||||
inputs=[
|
||||
io.MatchType.Input("input", template=template),
|
||||
io.DynamicCombo.Input("resize_type", options=[
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
|
||||
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
|
||||
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
crop_combo,
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
|
||||
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
|
||||
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
|
||||
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
|
||||
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
|
||||
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
||||
]),
|
||||
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
|
||||
io.MultiType.Input("match", [io.Image, io.Mask]),
|
||||
crop_combo,
|
||||
]),
|
||||
]),
|
||||
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
|
||||
],
|
||||
outputs=[io.MatchType.Output(template=template, display_name="resized")]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput:
|
||||
selected_type = resize_type["resize_type"]
|
||||
if selected_type == ResizeType.SCALE_BY:
|
||||
return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method))
|
||||
elif selected_type == ResizeType.SCALE_DIMENSIONS:
|
||||
return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"]))
|
||||
elif selected_type == ResizeType.SCALE_LONGER_DIMENSION:
|
||||
return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method))
|
||||
elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION:
|
||||
return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method))
|
||||
elif selected_type == ResizeType.SCALE_WIDTH:
|
||||
return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method))
|
||||
elif selected_type == ResizeType.SCALE_HEIGHT:
|
||||
return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method))
|
||||
elif selected_type == ResizeType.SCALE_TOTAL_PIXELS:
|
||||
return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
|
||||
elif selected_type == ResizeType.MATCH_SIZE:
|
||||
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
|
||||
raise ValueError(f"Unsupported resize type: {selected_type}")
|
||||
|
||||
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
|
||||
if len(images) == 0:
|
||||
return None
|
||||
# first, get the max channels count
|
||||
max_channels = max(image.shape[-1] for image in images)
|
||||
# then, pad all images to have the same channels count
|
||||
padded_images: list[torch.Tensor] = []
|
||||
for image in images:
|
||||
if image.shape[-1] < max_channels:
|
||||
padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0))
|
||||
else:
|
||||
padded_images.append(image)
|
||||
# resize all images to be the same size as the first image
|
||||
resized_images: list[torch.Tensor] = []
|
||||
first_image_shape = padded_images[0].shape
|
||||
for image in padded_images:
|
||||
if image.shape[1:] != first_image_shape[1:]:
|
||||
resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1))
|
||||
else:
|
||||
resized_images.append(image)
|
||||
# batch the images in the format [b, h, w, c]
|
||||
return torch.cat(resized_images, dim=0)
|
||||
|
||||
def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None:
|
||||
if len(masks) == 0:
|
||||
return None
|
||||
# resize all masks to be the same size as the first mask
|
||||
resized_masks: list[torch.Tensor] = []
|
||||
first_mask_shape = masks[0].shape
|
||||
for mask in masks:
|
||||
if mask.shape[1:] != first_mask_shape[1:]:
|
||||
mask = init_image_mask_input(mask, is_type_image=False)
|
||||
mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center")
|
||||
resized_masks.append(finalize_image_mask_input(mask, is_type_image=False))
|
||||
else:
|
||||
resized_masks.append(mask)
|
||||
# batch the masks in the format [b, h, w]
|
||||
return torch.cat(resized_masks, dim=0)
|
||||
|
||||
def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None:
|
||||
if len(latents) == 0:
|
||||
return None
|
||||
samples_out = latents[0].copy()
|
||||
samples_out["batch_index"] = []
|
||||
first_samples = latents[0]["samples"]
|
||||
tensors: list[torch.Tensor] = []
|
||||
for latent in latents:
|
||||
# first, deal with latent tensors
|
||||
tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False))
|
||||
# next, deal with batch_index
|
||||
samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])]))
|
||||
samples_out["samples"] = torch.cat(tensors, dim=0)
|
||||
return samples_out
|
||||
|
||||
class BatchImagesNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchImagesNode",
|
||||
display_name="Batch Images",
|
||||
category="image",
|
||||
inputs=[
|
||||
io.Autogrow.Input("images", template=autogrow_template)
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output()
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(batch_images(list(images.values())))
|
||||
|
||||
class BatchMasksNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchMasksNode",
|
||||
display_name="Batch Masks",
|
||||
category="mask",
|
||||
inputs=[
|
||||
io.Autogrow.Input("masks", template=autogrow_template)
|
||||
],
|
||||
outputs=[
|
||||
io.Mask.Output()
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(batch_masks(list(masks.values())))
|
||||
|
||||
class BatchLatentsNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchLatentsNode",
|
||||
display_name="Batch Latents",
|
||||
category="latent",
|
||||
inputs=[
|
||||
io.Autogrow.Input("latents", template=autogrow_template)
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output()
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(batch_latents(list(latents.values())))
|
||||
|
||||
class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent])
|
||||
autogrow_template = io.Autogrow.TemplatePrefix(
|
||||
io.MatchType.Input("input", matchtype_template),
|
||||
prefix="input", min=1, max=50)
|
||||
return io.Schema(
|
||||
node_id="BatchImagesMasksLatentsNode",
|
||||
display_name="Batch Images/Masks/Latents",
|
||||
category="util",
|
||||
inputs=[
|
||||
io.Autogrow.Input("inputs", template=autogrow_template)
|
||||
],
|
||||
outputs=[
|
||||
io.MatchType.Output(id=None, template=matchtype_template)
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
|
||||
batched = None
|
||||
values = list(inputs.values())
|
||||
# latents
|
||||
if isinstance(values[0], dict):
|
||||
batched = batch_latents(values)
|
||||
# images
|
||||
elif is_image(values[0]):
|
||||
batched = batch_images(values)
|
||||
# masks
|
||||
else:
|
||||
batched = batch_masks(values)
|
||||
return io.NodeOutput(batched)
|
||||
|
||||
class PostProcessingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -249,6 +601,11 @@ class PostProcessingExtension(ComfyExtension):
|
||||
Quantize,
|
||||
Sharpen,
|
||||
ImageScaleToTotalPixels,
|
||||
ResizeImageMaskNode,
|
||||
BatchImagesNode,
|
||||
BatchMasksNode,
|
||||
BatchLatentsNode,
|
||||
# BatchImagesMasksLatentsNode,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> PostProcessingExtension:
|
||||
|
||||
@ -66,7 +66,7 @@ class Float(io.ComfyNode):
|
||||
display_name="Float",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
|
||||
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1),
|
||||
],
|
||||
outputs=[io.Float.Output()],
|
||||
)
|
||||
|
||||
@ -3,7 +3,9 @@ import comfy.utils
|
||||
import math
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
import comfy.model_management
|
||||
import torch
|
||||
import nodes
|
||||
|
||||
class TextEncodeQwenImageEdit(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class EmptyQwenImageLayeredLatentImage(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyQwenImageLayeredLatentImage",
|
||||
display_name="Empty Qwen Image Layered Latent",
|
||||
category="latent/qwen",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent})
|
||||
|
||||
|
||||
class QwenExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeQwenImageEdit,
|
||||
TextEncodeQwenImageEditPlus,
|
||||
EmptyQwenImageLayeredLatentImage,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.samplers
|
||||
import comfy.sampler_helpers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
|
||||
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||
"""
|
||||
CFGGuider with modifications for training specific logic
|
||||
"""
|
||||
def outer_sample(
|
||||
self,
|
||||
noise,
|
||||
latent_image,
|
||||
sampler,
|
||||
sigmas,
|
||||
denoise_mask=None,
|
||||
callback=None,
|
||||
disable_pbar=False,
|
||||
seed=None,
|
||||
latent_shapes=None,
|
||||
):
|
||||
self.inner_model, self.conds, self.loaded_models = (
|
||||
comfy.sampler_helpers.prepare_sampling(
|
||||
self.model_patcher,
|
||||
noise.shape,
|
||||
self.conds,
|
||||
self.model_options,
|
||||
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
|
||||
)
|
||||
)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
denoise_mask = comfy.sampler_helpers.prepare_mask(
|
||||
denoise_mask, noise.shape, device
|
||||
)
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
sigmas = sigmas.to(device)
|
||||
comfy.samplers.cast_to_load_options(
|
||||
self.model_options, device=device, dtype=self.model_patcher.model_dtype()
|
||||
)
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(
|
||||
noise,
|
||||
latent_image,
|
||||
device,
|
||||
sampler,
|
||||
sigmas,
|
||||
denoise_mask,
|
||||
callback,
|
||||
disable_pbar,
|
||||
seed,
|
||||
latent_shapes=latent_shapes,
|
||||
)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
del self.loaded_models
|
||||
return output
|
||||
|
||||
|
||||
def make_batch_extra_option_dict(d, indicies, full_size=None):
|
||||
new_dict = {}
|
||||
for k, v in d.items():
|
||||
@ -65,6 +128,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
seed=0,
|
||||
training_dtype=torch.bfloat16,
|
||||
real_dataset=None,
|
||||
bucket_latents=None,
|
||||
):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
@ -75,6 +139,28 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.seed = seed
|
||||
self.training_dtype = training_dtype
|
||||
self.real_dataset: list[torch.Tensor] | None = real_dataset
|
||||
# Bucket mode data
|
||||
self.bucket_latents: list[torch.Tensor] | None = (
|
||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||
)
|
||||
# Precompute bucket offsets and weights for sampling
|
||||
if bucket_latents is not None:
|
||||
self._init_bucket_data(bucket_latents)
|
||||
else:
|
||||
self.bucket_offsets = None
|
||||
self.bucket_weights = None
|
||||
self.num_images = None
|
||||
|
||||
def _init_bucket_data(self, bucket_latents):
|
||||
"""Initialize bucket offsets and weights for sampling."""
|
||||
self.bucket_offsets = [0]
|
||||
bucket_sizes = []
|
||||
for lat in bucket_latents:
|
||||
bucket_sizes.append(lat.shape[0])
|
||||
self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
|
||||
self.num_images = self.bucket_offsets[-1]
|
||||
# Weights for sampling buckets proportional to their size
|
||||
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
|
||||
|
||||
def fwd_bwd(
|
||||
self,
|
||||
@ -115,6 +201,108 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
bwd_loss.backward()
|
||||
return loss
|
||||
|
||||
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
||||
"""Generate random sigma values for a batch."""
|
||||
batch_sigmas = [
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
return torch.tensor(batch_sigmas).to(device)
|
||||
|
||||
def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
|
||||
"""Execute one training step in bucket mode."""
|
||||
# Sample bucket (weighted by size), then sample batch from bucket
|
||||
bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
|
||||
bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi)
|
||||
bucket_size = bucket_latent.shape[0]
|
||||
bucket_offset = self.bucket_offsets[bucket_idx]
|
||||
|
||||
# Sample indices from this bucket (use all if bucket_size < batch_size)
|
||||
actual_batch_size = min(self.batch_size, bucket_size)
|
||||
relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
|
||||
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
|
||||
absolute_indices = [bucket_offset + idx for idx in relative_indices]
|
||||
|
||||
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
|
||||
batch_latent.device
|
||||
)
|
||||
batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
|
||||
|
||||
loss = self.fwd_bwd(
|
||||
model_wrap,
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
cond, # Use flattened cond with absolute indices
|
||||
absolute_indices,
|
||||
extra_args,
|
||||
self.num_images,
|
||||
bwd=True,
|
||||
)
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
|
||||
|
||||
def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
|
||||
"""Execute one training step in standard (non-bucket, non-multi-res) mode."""
|
||||
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
|
||||
batch_latent = torch.stack([latent_image[i] for i in indicies])
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
|
||||
batch_latent.device
|
||||
)
|
||||
batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)
|
||||
|
||||
loss = self.fwd_bwd(
|
||||
model_wrap,
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
cond,
|
||||
indicies,
|
||||
extra_args,
|
||||
dataset_size,
|
||||
bwd=True,
|
||||
)
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
|
||||
def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
|
||||
"""Execute one training step in multi-resolution mode (real_dataset is set)."""
|
||||
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
|
||||
total_loss = 0
|
||||
for index in indicies:
|
||||
single_latent = self.real_dataset[index].to(latent_image)
|
||||
batch_noise = noisegen.generate_noise(
|
||||
{"samples": single_latent}
|
||||
).to(single_latent.device)
|
||||
batch_sigmas = (
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
)
|
||||
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
|
||||
loss = self.fwd_bwd(
|
||||
model_wrap,
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
single_latent,
|
||||
cond,
|
||||
[index],
|
||||
extra_args,
|
||||
dataset_size,
|
||||
bwd=False,
|
||||
)
|
||||
total_loss += loss
|
||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||
total_loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(total_loss.item())
|
||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||
|
||||
def sample(
|
||||
self,
|
||||
model_wrap,
|
||||
@ -142,70 +330,18 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
|
||||
self.seed + i * 1000
|
||||
)
|
||||
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
|
||||
|
||||
if self.real_dataset is None:
|
||||
batch_latent = torch.stack([latent_image[i] for i in indicies])
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
|
||||
batch_latent.device
|
||||
)
|
||||
batch_sigmas = [
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
for _ in range(min(self.batch_size, dataset_size))
|
||||
]
|
||||
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
|
||||
|
||||
loss = self.fwd_bwd(
|
||||
model_wrap,
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
cond,
|
||||
indicies,
|
||||
extra_args,
|
||||
dataset_size,
|
||||
bwd=True,
|
||||
)
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
if self.bucket_latents is not None:
|
||||
self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
|
||||
elif self.real_dataset is None:
|
||||
self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||
else:
|
||||
total_loss = 0
|
||||
for index in indicies:
|
||||
single_latent = self.real_dataset[index].to(latent_image)
|
||||
batch_noise = noisegen.generate_noise(
|
||||
{"samples": single_latent}
|
||||
).to(single_latent.device)
|
||||
batch_sigmas = (
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
)
|
||||
)
|
||||
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
|
||||
loss = self.fwd_bwd(
|
||||
model_wrap,
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
single_latent,
|
||||
cond,
|
||||
[index],
|
||||
extra_args,
|
||||
dataset_size,
|
||||
bwd=False,
|
||||
)
|
||||
total_loss += loss
|
||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||
total_loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(total_loss.item())
|
||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||
|
||||
if (i + 1) % self.grad_acc == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
ui_pbar.update(1)
|
||||
ui_pbar.update(1)
|
||||
torch.cuda.empty_cache()
|
||||
return torch.zeros_like(latent_image)
|
||||
|
||||
@ -283,6 +419,364 @@ def unpatch(m):
|
||||
del m.org_forward
|
||||
|
||||
|
||||
def _process_latents_bucket_mode(latents):
|
||||
"""Process latents for bucket mode training.
|
||||
|
||||
Args:
|
||||
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
|
||||
|
||||
Returns:
|
||||
list of latent tensors
|
||||
"""
|
||||
bucket_latents = []
|
||||
for latent_dict in latents:
|
||||
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
|
||||
return bucket_latents
|
||||
|
||||
|
||||
def _process_latents_standard_mode(latents):
|
||||
"""Process latents for standard (non-bucket) mode training.
|
||||
|
||||
Args:
|
||||
latents: list of latent dicts or single latent dict
|
||||
|
||||
Returns:
|
||||
Processed latents (tensor or list of tensors)
|
||||
"""
|
||||
if len(latents) == 1:
|
||||
return latents[0]["samples"] # Single latent dict
|
||||
|
||||
latent_list = []
|
||||
for latent in latents:
|
||||
latent = latent["samples"]
|
||||
bs = latent.shape[0]
|
||||
if bs != 1:
|
||||
for sub_latent in latent:
|
||||
latent_list.append(sub_latent[None])
|
||||
else:
|
||||
latent_list.append(latent)
|
||||
return latent_list
|
||||
|
||||
|
||||
def _process_conditioning(positive):
|
||||
"""Process conditioning - either single list or list of lists.
|
||||
|
||||
Args:
|
||||
positive: list of conditioning
|
||||
|
||||
Returns:
|
||||
Flattened conditioning list
|
||||
"""
|
||||
if len(positive) == 1:
|
||||
return positive[0] # Single conditioning list
|
||||
|
||||
# Multiple conditioning lists - flatten
|
||||
flat_positive = []
|
||||
for cond in positive:
|
||||
if isinstance(cond, list):
|
||||
flat_positive.extend(cond)
|
||||
else:
|
||||
flat_positive.append(cond)
|
||||
return flat_positive
|
||||
|
||||
|
||||
def _prepare_latents_and_count(latents, dtype, bucket_mode):
|
||||
"""Convert latents to dtype and compute image counts.
|
||||
|
||||
Args:
|
||||
latents: Latents (tensor, list of tensors, or bucket list)
|
||||
dtype: Target dtype
|
||||
bucket_mode: Whether bucket mode is enabled
|
||||
|
||||
Returns:
|
||||
tuple: (processed_latents, num_images, multi_res)
|
||||
"""
|
||||
if bucket_mode:
|
||||
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
|
||||
latents = [t.to(dtype) for t in latents]
|
||||
num_buckets = len(latents)
|
||||
num_images = sum(t.shape[0] for t in latents)
|
||||
multi_res = False # Not using multi_res path in bucket mode
|
||||
|
||||
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
||||
for i, lat in enumerate(latents):
|
||||
logging.info(f" Bucket {i}: shape {lat.shape}")
|
||||
return latents, num_images, multi_res
|
||||
|
||||
# Non-bucket mode
|
||||
if isinstance(latents, list):
|
||||
all_shapes = set()
|
||||
latents = [t.to(dtype) for t in latents]
|
||||
for latent in latents:
|
||||
all_shapes.add(latent.shape)
|
||||
logging.info(f"Latent shapes: {all_shapes}")
|
||||
if len(all_shapes) > 1:
|
||||
multi_res = True
|
||||
else:
|
||||
multi_res = False
|
||||
latents = torch.cat(latents, dim=0)
|
||||
num_images = len(latents)
|
||||
elif isinstance(latents, torch.Tensor):
|
||||
latents = latents.to(dtype)
|
||||
num_images = latents.shape[0]
|
||||
multi_res = False
|
||||
else:
|
||||
logging.error(f"Invalid latents type: {type(latents)}")
|
||||
num_images = 0
|
||||
multi_res = False
|
||||
|
||||
return latents, num_images, multi_res
|
||||
|
||||
|
||||
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
|
||||
"""Validate conditioning count matches image count, expand if needed.
|
||||
|
||||
Args:
|
||||
positive: Conditioning list
|
||||
num_images: Number of images
|
||||
bucket_mode: Whether bucket mode is enabled
|
||||
|
||||
Returns:
|
||||
Validated/expanded conditioning list
|
||||
|
||||
Raises:
|
||||
ValueError: If conditioning count doesn't match image count
|
||||
"""
|
||||
if bucket_mode:
|
||||
return positive # Skip validation in bucket mode
|
||||
|
||||
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||
if len(positive) == 1 and num_images > 1:
|
||||
return positive * num_images
|
||||
elif len(positive) != num_images:
|
||||
raise ValueError(
|
||||
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
||||
)
|
||||
return positive
|
||||
|
||||
|
||||
def _load_existing_lora(existing_lora):
|
||||
"""Load existing LoRA weights if provided.
|
||||
|
||||
Args:
|
||||
existing_lora: LoRA filename or "[None]"
|
||||
|
||||
Returns:
|
||||
tuple: (existing_weights dict, existing_steps int)
|
||||
"""
|
||||
if existing_lora == "[None]":
|
||||
return {}, 0
|
||||
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
||||
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
||||
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
||||
existing_weights = {}
|
||||
if lora_path:
|
||||
existing_weights = comfy.utils.load_torch_file(lora_path)
|
||||
return existing_weights, existing_steps
|
||||
|
||||
|
||||
def _create_weight_adapter(
|
||||
module, module_name, existing_weights, algorithm, lora_dtype, rank
|
||||
):
|
||||
"""Create a weight adapter for a module with weight.
|
||||
|
||||
Args:
|
||||
module: The module to create adapter for
|
||||
module_name: Name of the module
|
||||
existing_weights: Dict of existing LoRA weights
|
||||
algorithm: Algorithm name for new adapters
|
||||
lora_dtype: dtype for LoRA weights
|
||||
rank: Rank for new LoRA adapters
|
||||
|
||||
Returns:
|
||||
tuple: (train_adapter, lora_params dict)
|
||||
"""
|
||||
key = f"{module_name}.weight"
|
||||
shape = module.weight.shape
|
||||
lora_params = {}
|
||||
|
||||
if len(shape) >= 2:
|
||||
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
|
||||
|
||||
# Try to load existing adapter
|
||||
existing_adapter = None
|
||||
for adapter_cls in adapters:
|
||||
existing_adapter = adapter_cls.load(
|
||||
module_name, existing_weights, alpha, dora_scale
|
||||
)
|
||||
if existing_adapter is not None:
|
||||
break
|
||||
|
||||
if existing_adapter is None:
|
||||
adapter_cls = adapter_maps[algorithm]
|
||||
|
||||
if existing_adapter is not None:
|
||||
train_adapter = existing_adapter.to_train().to(lora_dtype)
|
||||
else:
|
||||
# Use LoRA with alpha=1.0 by default
|
||||
train_adapter = adapter_cls.create_train(
|
||||
module.weight, rank=rank, alpha=1.0
|
||||
).to(lora_dtype)
|
||||
|
||||
for name, parameter in train_adapter.named_parameters():
|
||||
lora_params[f"{module_name}.{name}"] = parameter
|
||||
|
||||
return train_adapter.train().requires_grad_(True), lora_params
|
||||
else:
|
||||
# 1D weight - use BiasDiff
|
||||
diff = torch.nn.Parameter(
|
||||
torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
|
||||
)
|
||||
diff_module = BiasDiff(diff).train().requires_grad_(True)
|
||||
lora_params[f"{module_name}.diff"] = diff
|
||||
return diff_module, lora_params
|
||||
|
||||
|
||||
def _create_bias_adapter(module, module_name, lora_dtype):
|
||||
"""Create a bias adapter for a module with bias.
|
||||
|
||||
Args:
|
||||
module: The module with bias
|
||||
module_name: Name of the module
|
||||
lora_dtype: dtype for LoRA weights
|
||||
|
||||
Returns:
|
||||
tuple: (bias_module, lora_params dict)
|
||||
"""
|
||||
bias = torch.nn.Parameter(
|
||||
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
|
||||
)
|
||||
bias_module = BiasDiff(bias).train().requires_grad_(True)
|
||||
lora_params = {f"{module_name}.diff_b": bias}
|
||||
return bias_module, lora_params
|
||||
|
||||
|
||||
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
|
||||
"""Setup all LoRA adapters on the model.
|
||||
|
||||
Args:
|
||||
mp: Model patcher
|
||||
existing_weights: Dict of existing LoRA weights
|
||||
algorithm: Algorithm name for new adapters
|
||||
lora_dtype: dtype for LoRA weights
|
||||
rank: Rank for new LoRA adapters
|
||||
|
||||
Returns:
|
||||
tuple: (lora_sd dict, all_weight_adapters list)
|
||||
"""
|
||||
lora_sd = {}
|
||||
all_weight_adapters = []
|
||||
|
||||
for n, m in mp.model.named_modules():
|
||||
if hasattr(m, "weight_function"):
|
||||
if m.weight is not None:
|
||||
adapter, params = _create_weight_adapter(
|
||||
m, n, existing_weights, algorithm, lora_dtype, rank
|
||||
)
|
||||
lora_sd.update(params)
|
||||
key = f"{n}.weight"
|
||||
mp.add_weight_wrapper(key, adapter)
|
||||
all_weight_adapters.append(adapter)
|
||||
|
||||
if hasattr(m, "bias") and m.bias is not None:
|
||||
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
|
||||
lora_sd.update(bias_params)
|
||||
key = f"{n}.bias"
|
||||
mp.add_weight_wrapper(key, bias_adapter)
|
||||
all_weight_adapters.append(bias_adapter)
|
||||
|
||||
return lora_sd, all_weight_adapters
|
||||
|
||||
|
||||
def _create_optimizer(optimizer_name, parameters, learning_rate):
|
||||
"""Create optimizer based on name.
|
||||
|
||||
Args:
|
||||
optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
|
||||
parameters: Parameters to optimize
|
||||
learning_rate: Learning rate
|
||||
|
||||
Returns:
|
||||
Optimizer instance
|
||||
"""
|
||||
if optimizer_name == "Adam":
|
||||
return torch.optim.Adam(parameters, lr=learning_rate)
|
||||
elif optimizer_name == "AdamW":
|
||||
return torch.optim.AdamW(parameters, lr=learning_rate)
|
||||
elif optimizer_name == "SGD":
|
||||
return torch.optim.SGD(parameters, lr=learning_rate)
|
||||
elif optimizer_name == "RMSprop":
|
||||
return torch.optim.RMSprop(parameters, lr=learning_rate)
|
||||
|
||||
|
||||
def _create_loss_function(loss_function_name):
|
||||
"""Create loss function based on name.
|
||||
|
||||
Args:
|
||||
loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
|
||||
|
||||
Returns:
|
||||
Loss function instance
|
||||
"""
|
||||
if loss_function_name == "MSE":
|
||||
return torch.nn.MSELoss()
|
||||
elif loss_function_name == "L1":
|
||||
return torch.nn.L1Loss()
|
||||
elif loss_function_name == "Huber":
|
||||
return torch.nn.HuberLoss()
|
||||
elif loss_function_name == "SmoothL1":
|
||||
return torch.nn.SmoothL1Loss()
|
||||
|
||||
|
||||
def _run_training_loop(
|
||||
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
|
||||
):
|
||||
"""Execute the training loop.
|
||||
|
||||
Args:
|
||||
guider: The guider object
|
||||
train_sampler: The training sampler
|
||||
latents: Latent tensors
|
||||
num_images: Number of images
|
||||
seed: Random seed
|
||||
bucket_mode: Whether bucket mode is enabled
|
||||
multi_res: Whether multi-resolution mode is enabled
|
||||
"""
|
||||
sigmas = torch.tensor(range(num_images))
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||
|
||||
if bucket_mode:
|
||||
# Use first bucket's first latent as dummy for guider
|
||||
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": dummy_latent}),
|
||||
dummy_latent,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed,
|
||||
)
|
||||
elif multi_res:
|
||||
# use first latent as dummy latent if multi_res
|
||||
latents = latents[0].repeat(num_images, 1, 1, 1)
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed,
|
||||
)
|
||||
else:
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed,
|
||||
)
|
||||
|
||||
|
||||
class TrainLoraNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -385,6 +879,11 @@ class TrainLoraNode(io.ComfyNode):
|
||||
default="[None]",
|
||||
tooltip="The existing LoRA to append to. Set to None for new LoRA.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"bucket_mode",
|
||||
default=False,
|
||||
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(
|
||||
@ -419,6 +918,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
existing_lora,
|
||||
bucket_mode,
|
||||
):
|
||||
# Extract scalars from lists (due to is_input_list=True)
|
||||
model = model[0]
|
||||
@ -427,215 +927,125 @@ class TrainLoraNode(io.ComfyNode):
|
||||
grad_accumulation_steps = grad_accumulation_steps[0]
|
||||
learning_rate = learning_rate[0]
|
||||
rank = rank[0]
|
||||
optimizer = optimizer[0]
|
||||
loss_function = loss_function[0]
|
||||
optimizer_name = optimizer[0]
|
||||
loss_function_name = loss_function[0]
|
||||
seed = seed[0]
|
||||
training_dtype = training_dtype[0]
|
||||
lora_dtype = lora_dtype[0]
|
||||
algorithm = algorithm[0]
|
||||
gradient_checkpointing = gradient_checkpointing[0]
|
||||
existing_lora = existing_lora[0]
|
||||
bucket_mode = bucket_mode[0]
|
||||
|
||||
# Handle latents - either single dict or list of dicts
|
||||
if len(latents) == 1:
|
||||
latents = latents[0]["samples"] # Single latent dict
|
||||
# Process latents based on mode
|
||||
if bucket_mode:
|
||||
latents = _process_latents_bucket_mode(latents)
|
||||
else:
|
||||
latent_list = []
|
||||
for latent in latents:
|
||||
latent = latent["samples"]
|
||||
bs = latent.shape[0]
|
||||
if bs != 1:
|
||||
for sub_latent in latent:
|
||||
latent_list.append(sub_latent[None])
|
||||
else:
|
||||
latent_list.append(latent)
|
||||
latents = latent_list
|
||||
latents = _process_latents_standard_mode(latents)
|
||||
|
||||
# Handle conditioning - either single list or list of lists
|
||||
if len(positive) == 1:
|
||||
positive = positive[0] # Single conditioning list
|
||||
else:
|
||||
# Multiple conditioning lists - flatten
|
||||
flat_positive = []
|
||||
for cond in positive:
|
||||
if isinstance(cond, list):
|
||||
flat_positive.extend(cond)
|
||||
else:
|
||||
flat_positive.append(cond)
|
||||
positive = flat_positive
|
||||
# Process conditioning
|
||||
positive = _process_conditioning(positive)
|
||||
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
# latents here can be list of different size latent or one large batch
|
||||
if isinstance(latents, list):
|
||||
all_shapes = set()
|
||||
latents = [t.to(dtype) for t in latents]
|
||||
for latent in latents:
|
||||
all_shapes.add(latent.shape)
|
||||
logging.info(f"Latent shapes: {all_shapes}")
|
||||
if len(all_shapes) > 1:
|
||||
multi_res = True
|
||||
else:
|
||||
multi_res = False
|
||||
latents = torch.cat(latents, dim=0)
|
||||
num_images = len(latents)
|
||||
elif isinstance(latents, torch.Tensor):
|
||||
latents = latents.to(dtype)
|
||||
num_images = latents.shape[0]
|
||||
else:
|
||||
logging.error(f"Invalid latents type: {type(latents)}")
|
||||
# Prepare latents and compute counts
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, dtype, bucket_mode
|
||||
)
|
||||
|
||||
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||
if len(positive) == 1 and num_images > 1:
|
||||
positive = positive * num_images
|
||||
elif len(positive) != num_images:
|
||||
raise ValueError(
|
||||
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
||||
)
|
||||
# Validate and expand conditioning
|
||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||
|
||||
with torch.inference_mode(False):
|
||||
lora_sd = {}
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
# Setup models for training
|
||||
mp.model.requires_grad_(False)
|
||||
|
||||
# Load existing LoRA weights if provided
|
||||
existing_weights = {}
|
||||
existing_steps = 0
|
||||
if existing_lora != "[None]":
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
||||
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
||||
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
||||
if lora_path:
|
||||
existing_weights = comfy.utils.load_torch_file(lora_path)
|
||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||
|
||||
all_weight_adapters = []
|
||||
for n, m in mp.model.named_modules():
|
||||
if hasattr(m, "weight_function"):
|
||||
if m.weight is not None:
|
||||
key = "{}.weight".format(n)
|
||||
shape = m.weight.shape
|
||||
if len(shape) >= 2:
|
||||
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
|
||||
for adapter_cls in adapters:
|
||||
existing_adapter = adapter_cls.load(
|
||||
n, existing_weights, alpha, dora_scale
|
||||
)
|
||||
if existing_adapter is not None:
|
||||
break
|
||||
else:
|
||||
existing_adapter = None
|
||||
adapter_cls = adapter_maps[algorithm]
|
||||
# Setup LoRA adapters
|
||||
lora_sd, all_weight_adapters = _setup_lora_adapters(
|
||||
mp, existing_weights, algorithm, lora_dtype, rank
|
||||
)
|
||||
|
||||
if existing_adapter is not None:
|
||||
train_adapter = existing_adapter.to_train().to(
|
||||
lora_dtype
|
||||
)
|
||||
else:
|
||||
# Use LoRA with alpha=1.0 by default
|
||||
train_adapter = adapter_cls.create_train(
|
||||
m.weight, rank=rank, alpha=1.0
|
||||
).to(lora_dtype)
|
||||
for name, parameter in train_adapter.named_parameters():
|
||||
lora_sd[f"{n}.{name}"] = parameter
|
||||
# Create optimizer and loss function
|
||||
optimizer = _create_optimizer(
|
||||
optimizer_name, lora_sd.values(), learning_rate
|
||||
)
|
||||
criterion = _create_loss_function(loss_function_name)
|
||||
|
||||
mp.add_weight_wrapper(key, train_adapter)
|
||||
all_weight_adapters.append(train_adapter)
|
||||
else:
|
||||
diff = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
m.weight.shape, dtype=lora_dtype, requires_grad=True
|
||||
)
|
||||
)
|
||||
diff_module = BiasDiff(diff)
|
||||
mp.add_weight_wrapper(key, BiasDiff(diff))
|
||||
all_weight_adapters.append(diff_module)
|
||||
lora_sd["{}.diff".format(n)] = diff
|
||||
if hasattr(m, "bias") and m.bias is not None:
|
||||
key = "{}.bias".format(n)
|
||||
bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
m.bias.shape, dtype=lora_dtype, requires_grad=True
|
||||
)
|
||||
)
|
||||
bias_module = BiasDiff(bias)
|
||||
lora_sd["{}.diff_b".format(n)] = bias
|
||||
mp.add_weight_wrapper(key, BiasDiff(bias))
|
||||
all_weight_adapters.append(bias_module)
|
||||
|
||||
if optimizer == "Adam":
|
||||
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "AdamW":
|
||||
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "SGD":
|
||||
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "RMSprop":
|
||||
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
|
||||
|
||||
# Setup loss function based on selection
|
||||
if loss_function == "MSE":
|
||||
criterion = torch.nn.MSELoss()
|
||||
elif loss_function == "L1":
|
||||
criterion = torch.nn.L1Loss()
|
||||
elif loss_function == "Huber":
|
||||
criterion = torch.nn.HuberLoss()
|
||||
elif loss_function == "SmoothL1":
|
||||
criterion = torch.nn.SmoothL1Loss()
|
||||
|
||||
# setup models
|
||||
# Setup gradient checkpointing
|
||||
if gradient_checkpointing:
|
||||
for m in find_all_highest_child_module_with_forward(
|
||||
mp.model.diffusion_model
|
||||
):
|
||||
patch(m)
|
||||
mp.model.requires_grad_(False)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
# With force_full_load=False we should be able to have offloading
|
||||
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
||||
comfy.model_management.load_models_gpu(
|
||||
[mp], memory_required=1e20, force_full_load=True
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Setup sampler and guider like in test script
|
||||
# Setup loss tracking
|
||||
loss_map = {"loss": []}
|
||||
|
||||
def loss_callback(loss):
|
||||
loss_map["loss"].append(loss)
|
||||
|
||||
train_sampler = TrainSampler(
|
||||
criterion,
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
grad_acc=grad_accumulation_steps,
|
||||
total_steps=steps * grad_accumulation_steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
real_dataset=latents if multi_res else None,
|
||||
)
|
||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||
guider.set_conds(positive) # Set conditioning from input
|
||||
# Create sampler
|
||||
if bucket_mode:
|
||||
train_sampler = TrainSampler(
|
||||
criterion,
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
grad_acc=grad_accumulation_steps,
|
||||
total_steps=steps * grad_accumulation_steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
bucket_latents=latents,
|
||||
)
|
||||
else:
|
||||
train_sampler = TrainSampler(
|
||||
criterion,
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
grad_acc=grad_accumulation_steps,
|
||||
total_steps=steps * grad_accumulation_steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
real_dataset=latents if multi_res else None,
|
||||
)
|
||||
|
||||
# Training loop
|
||||
# Setup guider
|
||||
guider = TrainGuider(mp)
|
||||
guider.set_conds(positive)
|
||||
|
||||
# Run training loop
|
||||
try:
|
||||
# Generate dummy sigmas and noise
|
||||
sigmas = torch.tensor(range(num_images))
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||
if multi_res:
|
||||
# use first latent as dummy latent if multi_res
|
||||
latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
_run_training_loop(
|
||||
guider,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed,
|
||||
latents,
|
||||
num_images,
|
||||
seed,
|
||||
bucket_mode,
|
||||
multi_res,
|
||||
)
|
||||
finally:
|
||||
for m in mp.model.modules():
|
||||
unpatch(m)
|
||||
del train_sampler, optimizer
|
||||
|
||||
# Finalize adapters
|
||||
for adapter in all_weight_adapters:
|
||||
adapter.requires_grad_(False)
|
||||
|
||||
@ -645,7 +1055,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
|
||||
|
||||
|
||||
class LoraModelLoader(io.ComfyNode):
|
||||
class LoraModelLoader(io.ComfyNode):#
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
|
||||
@ -817,7 +817,7 @@ def get_sample_indices(original_fps,
|
||||
if required_duration > total_frames / original_fps:
|
||||
raise ValueError("required_duration must be less than video length")
|
||||
|
||||
if not fixed_start is None and fixed_start >= 0:
|
||||
if fixed_start is not None and fixed_start >= 0:
|
||||
start_frame = fixed_start
|
||||
else:
|
||||
max_start = total_frames - required_origin_frames
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.4.0"
|
||||
__version__ = "0.7.0"
|
||||
|
||||
46
execution.py
46
execution.py
@ -79,7 +79,7 @@ class IsChangedCache:
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
|
||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except Exception as e:
|
||||
@ -148,13 +148,12 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||
v3_data: io.V3Data = {}
|
||||
hidden_inputs_v3 = {}
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
if is_v3:
|
||||
valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
||||
else:
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs)
|
||||
input_data_all = {}
|
||||
missing_keys = {}
|
||||
hidden_inputs_v3 = {}
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||
@ -180,18 +179,18 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if is_v3:
|
||||
if schema.hidden:
|
||||
if io.Hidden.prompt in schema.hidden:
|
||||
if hidden is not None:
|
||||
if io.Hidden.prompt.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||
if io.Hidden.dynprompt in schema.hidden:
|
||||
if io.Hidden.dynprompt.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
||||
if io.Hidden.extra_pnginfo in schema.hidden:
|
||||
if io.Hidden.extra_pnginfo.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||
if io.Hidden.unique_id in schema.hidden:
|
||||
if io.Hidden.unique_id.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
||||
if io.Hidden.auth_token_comfy_org in schema.hidden:
|
||||
if io.Hidden.auth_token_comfy_org.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||
if io.Hidden.api_key_comfy_org in schema.hidden:
|
||||
if io.Hidden.api_key_comfy_org.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||
else:
|
||||
if "hidden" in valid_inputs:
|
||||
@ -258,7 +257,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
pre_execute_cb(index)
|
||||
# V3
|
||||
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||
# if is just a class, then assign no resources or state, just create clone
|
||||
# if is just a class, then assign no state, just create clone
|
||||
if is_class(obj):
|
||||
type_obj = obj
|
||||
obj.VALIDATE_CLASS()
|
||||
@ -481,7 +480,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
else:
|
||||
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||
if lazy_status_present:
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
|
||||
# for check_lazy_status, the returned data should include the original key of the input
|
||||
v3_data_lazy = v3_data.copy()
|
||||
v3_data_lazy["create_dynamic_tuple"] = True
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy)
|
||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
@ -599,6 +601,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
|
||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
||||
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
@ -756,10 +759,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
errors = []
|
||||
valid = True
|
||||
|
||||
v3_data = None
|
||||
validate_function_inputs = []
|
||||
validate_has_kwargs = False
|
||||
if issubclass(obj_class, _ComfyNodeInternal):
|
||||
class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
|
||||
obj_class: _io._ComfyNodeBaseInternal
|
||||
class_inputs = obj_class.INPUT_TYPES()
|
||||
class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs)
|
||||
validate_function_name = "validate_inputs"
|
||||
validate_function = first_real_override(obj_class, validate_function_name)
|
||||
else:
|
||||
@ -779,10 +785,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
assert extra_info is not None
|
||||
if x not in inputs:
|
||||
if input_category == "required":
|
||||
details = f"{x}" if not v3_data else x.split(".")[-1]
|
||||
error = {
|
||||
"type": "required_input_missing",
|
||||
"message": "Required input is missing",
|
||||
"details": f"{x}",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x
|
||||
}
|
||||
@ -916,8 +923,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if isinstance(input_type, list):
|
||||
combo_options = input_type
|
||||
if isinstance(input_type, list) or input_type == io.Combo.io_type:
|
||||
if input_type == io.Combo.io_type:
|
||||
combo_options = extra_info.get("options", [])
|
||||
else:
|
||||
combo_options = input_type
|
||||
if val not in combo_options:
|
||||
input_config = info
|
||||
list_info = ""
|
||||
|
||||
66
main.py
66
main.py
@ -23,6 +23,38 @@ if __name__ == "__main__":
|
||||
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
if os.name == "nt":
|
||||
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
|
||||
if args.default_device is not None:
|
||||
default_dev = args.default_device
|
||||
devices = list(range(32))
|
||||
devices.remove(default_dev)
|
||||
devices.insert(0, default_dev)
|
||||
devices = ','.join(map(str, devices))
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
|
||||
|
||||
if args.cuda_device is not None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
|
||||
logging.info("Set cuda device to: {}".format(args.cuda_device))
|
||||
|
||||
if args.oneapi_device_selector is not None:
|
||||
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
|
||||
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
|
||||
|
||||
if args.deterministic:
|
||||
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
|
||||
|
||||
import cuda_malloc
|
||||
if "rocm" in cuda_malloc.get_torch_version_noimport():
|
||||
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
|
||||
|
||||
|
||||
def handle_comfyui_manager_unavailable():
|
||||
if not args.windows_standalone_build:
|
||||
@ -137,40 +169,6 @@ import shutil
|
||||
import threading
|
||||
import gc
|
||||
|
||||
|
||||
if os.name == "nt":
|
||||
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
|
||||
if args.default_device is not None:
|
||||
default_dev = args.default_device
|
||||
devices = list(range(32))
|
||||
devices.remove(default_dev)
|
||||
devices.insert(0, default_dev)
|
||||
devices = ','.join(map(str, devices))
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
|
||||
|
||||
if args.cuda_device is not None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
|
||||
logging.info("Set cuda device to: {}".format(args.cuda_device))
|
||||
|
||||
if args.oneapi_device_selector is not None:
|
||||
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
|
||||
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
|
||||
|
||||
if args.deterministic:
|
||||
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
|
||||
|
||||
import cuda_malloc
|
||||
if "rocm" in cuda_malloc.get_torch_version_noimport():
|
||||
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
|
||||
|
||||
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
|
||||
@ -1 +1 @@
|
||||
comfyui_manager==4.0.3b5
|
||||
comfyui_manager==4.0.4
|
||||
|
||||
22
nodes.py
22
nodes.py
@ -343,7 +343,7 @@ class VAEEncode:
|
||||
CATEGORY = "latent"
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
t = vae.encode(pixels[:,:,:,:3])
|
||||
t = vae.encode(pixels)
|
||||
return ({"samples":t}, )
|
||||
|
||||
class VAEEncodeTiled:
|
||||
@ -361,7 +361,7 @@ class VAEEncodeTiled:
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
||||
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||
return ({"samples": t}, )
|
||||
|
||||
class VAEEncodeForInpaint:
|
||||
@ -970,7 +970,7 @@ class DualCLIPLoader:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -980,7 +980,7 @@ class DualCLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small"
|
||||
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
|
||||
|
||||
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
@ -1663,8 +1663,6 @@ class LoadImage:
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
excluded_formats = ['MPO']
|
||||
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
|
||||
@ -1692,7 +1690,10 @@ class LoadImage:
|
||||
output_images.append(image)
|
||||
output_masks.append(mask.unsqueeze(0))
|
||||
|
||||
if len(output_images) > 1 and img.format not in excluded_formats:
|
||||
if img.format == "MPO":
|
||||
break # ignore all frames except the first one for MPO format
|
||||
|
||||
if len(output_images) > 1:
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
else:
|
||||
@ -1863,6 +1864,7 @@ class ImageBatch:
|
||||
FUNCTION = "batch"
|
||||
|
||||
CATEGORY = "image"
|
||||
DEPRECATED = True
|
||||
|
||||
def batch(self, image1, image2):
|
||||
if image1.shape[-1] != image2.shape[-1]:
|
||||
@ -2241,8 +2243,10 @@ async def init_external_custom_nodes():
|
||||
|
||||
for possible_module in possible_modules:
|
||||
module_path = os.path.join(custom_node_path, possible_module)
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
if module_path.endswith(".disabled"): continue
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py":
|
||||
continue
|
||||
if module_path.endswith(".disabled"):
|
||||
continue
|
||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
||||
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
||||
continue
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.4.0"
|
||||
version = "0.7.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
@ -15,12 +15,16 @@ lint.select = [
|
||||
"N805", # invalid-first-argument-name-for-method
|
||||
"S307", # suspicious-eval-usage
|
||||
"S102", # exec
|
||||
"E",
|
||||
"T", # print-usage
|
||||
"W",
|
||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
"F",
|
||||
]
|
||||
|
||||
lint.ignore = ["E501", "E722", "E731", "E712", "E402", "E741"]
|
||||
|
||||
exclude = ["*.ipynb", "**/generated/*.pyi"]
|
||||
|
||||
[tool.pylint]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.34.9
|
||||
comfyui-workflow-templates==0.7.59
|
||||
comfyui-frontend-package==1.35.9
|
||||
comfyui-workflow-templates==0.7.65
|
||||
comfyui-embedded-docs==0.3.1
|
||||
torch
|
||||
torchsde
|
||||
|
||||
141
server.py
141
server.py
@ -7,6 +7,7 @@ import time
|
||||
import nodes
|
||||
import folder_paths
|
||||
import execution
|
||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
|
||||
import uuid
|
||||
import urllib
|
||||
import json
|
||||
@ -47,6 +48,12 @@ from middleware.cache_middleware import cache_control
|
||||
if args.enable_manager:
|
||||
import comfyui_manager
|
||||
|
||||
|
||||
def _remove_sensitive_from_queue(queue: list) -> list:
|
||||
"""Remove sensitive data (index 5) from queue item tuples."""
|
||||
return [item[:5] for item in queue]
|
||||
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
await function(message)
|
||||
@ -317,7 +324,7 @@ class PromptServer():
|
||||
@routes.get("/models/{folder}")
|
||||
async def get_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
if folder not in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = folder_paths.get_filename_list(folder)
|
||||
return web.json_response(files)
|
||||
@ -572,7 +579,7 @@ class PromptServer():
|
||||
folder_name = request.match_info.get("folder_name", None)
|
||||
if folder_name is None:
|
||||
return web.Response(status=404)
|
||||
if not "filename" in request.rel_url.query:
|
||||
if "filename" not in request.rel_url.query:
|
||||
return web.Response(status=404)
|
||||
|
||||
filename = request.rel_url.query["filename"]
|
||||
@ -586,7 +593,7 @@ class PromptServer():
|
||||
if out is None:
|
||||
return web.Response(status=404)
|
||||
dt = json.loads(out)
|
||||
if not "__metadata__" in dt:
|
||||
if "__metadata__" not in dt:
|
||||
return web.Response(status=404)
|
||||
return web.json_response(dt["__metadata__"])
|
||||
|
||||
@ -694,6 +701,129 @@ class PromptServer():
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/api/jobs")
|
||||
async def get_jobs(request):
|
||||
"""List all jobs with filtering, sorting, and pagination.
|
||||
|
||||
Query parameters:
|
||||
status: Filter by status (comma-separated): pending, in_progress, completed, failed
|
||||
workflow_id: Filter by workflow ID
|
||||
sort_by: Sort field: created_at (default), execution_duration
|
||||
sort_order: Sort direction: asc, desc (default)
|
||||
limit: Max items to return (positive integer)
|
||||
offset: Items to skip (non-negative integer, default 0)
|
||||
"""
|
||||
query = request.rel_url.query
|
||||
|
||||
status_param = query.get('status')
|
||||
workflow_id = query.get('workflow_id')
|
||||
sort_by = query.get('sort_by', 'created_at').lower()
|
||||
sort_order = query.get('sort_order', 'desc').lower()
|
||||
|
||||
status_filter = None
|
||||
if status_param:
|
||||
status_filter = [s.strip().lower() for s in status_param.split(',') if s.strip()]
|
||||
invalid_statuses = [s for s in status_filter if s not in JobStatus.ALL]
|
||||
if invalid_statuses:
|
||||
return web.json_response(
|
||||
{"error": f"Invalid status value(s): {', '.join(invalid_statuses)}. Valid values: {', '.join(JobStatus.ALL)}"},
|
||||
status=400
|
||||
)
|
||||
|
||||
if sort_by not in {'created_at', 'execution_duration'}:
|
||||
return web.json_response(
|
||||
{"error": "sort_by must be 'created_at' or 'execution_duration'"},
|
||||
status=400
|
||||
)
|
||||
|
||||
if sort_order not in {'asc', 'desc'}:
|
||||
return web.json_response(
|
||||
{"error": "sort_order must be 'asc' or 'desc'"},
|
||||
status=400
|
||||
)
|
||||
|
||||
limit = None
|
||||
|
||||
# If limit is provided, validate that it is a positive integer, else continue without a limit
|
||||
if 'limit' in query:
|
||||
try:
|
||||
limit = int(query.get('limit'))
|
||||
if limit <= 0:
|
||||
return web.json_response(
|
||||
{"error": "limit must be a positive integer"},
|
||||
status=400
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
return web.json_response(
|
||||
{"error": "limit must be an integer"},
|
||||
status=400
|
||||
)
|
||||
|
||||
offset = 0
|
||||
if 'offset' in query:
|
||||
try:
|
||||
offset = int(query.get('offset'))
|
||||
if offset < 0:
|
||||
offset = 0
|
||||
except (ValueError, TypeError):
|
||||
return web.json_response(
|
||||
{"error": "offset must be an integer"},
|
||||
status=400
|
||||
)
|
||||
|
||||
running, queued = self.prompt_queue.get_current_queue_volatile()
|
||||
history = self.prompt_queue.get_history()
|
||||
|
||||
running = _remove_sensitive_from_queue(running)
|
||||
queued = _remove_sensitive_from_queue(queued)
|
||||
|
||||
jobs, total = get_all_jobs(
|
||||
running, queued, history,
|
||||
status_filter=status_filter,
|
||||
workflow_id=workflow_id,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
has_more = (offset + len(jobs)) < total
|
||||
|
||||
return web.json_response({
|
||||
'jobs': jobs,
|
||||
'pagination': {
|
||||
'offset': offset,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'has_more': has_more
|
||||
}
|
||||
})
|
||||
|
||||
@routes.get("/api/jobs/{job_id}")
|
||||
async def get_job_by_id(request):
|
||||
"""Get a single job by ID."""
|
||||
job_id = request.match_info.get("job_id", None)
|
||||
if not job_id:
|
||||
return web.json_response(
|
||||
{"error": "job_id is required"},
|
||||
status=400
|
||||
)
|
||||
|
||||
running, queued = self.prompt_queue.get_current_queue_volatile()
|
||||
history = self.prompt_queue.get_history(prompt_id=job_id)
|
||||
|
||||
running = _remove_sensitive_from_queue(running)
|
||||
queued = _remove_sensitive_from_queue(queued)
|
||||
|
||||
job = get_job(job_id, running, queued, history)
|
||||
if job is None:
|
||||
return web.json_response(
|
||||
{"error": "Job not found"},
|
||||
status=404
|
||||
)
|
||||
|
||||
return web.json_response(job)
|
||||
|
||||
@routes.get("/history")
|
||||
async def get_history(request):
|
||||
max_items = request.rel_url.query.get("max_items", None)
|
||||
@ -717,9 +847,8 @@ class PromptServer():
|
||||
async def get_queue(request):
|
||||
queue_info = {}
|
||||
current_queue = self.prompt_queue.get_current_queue_volatile()
|
||||
remove_sensitive = lambda queue: [x[:5] for x in queue]
|
||||
queue_info['queue_running'] = remove_sensitive(current_queue[0])
|
||||
queue_info['queue_pending'] = remove_sensitive(current_queue[1])
|
||||
queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0])
|
||||
queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
|
||||
return web.json_response(queue_info)
|
||||
|
||||
@routes.post("/prompt")
|
||||
|
||||
@ -25,7 +25,7 @@ class TestImageStitch:
|
||||
|
||||
result = node.stitch(image1, "right", True, 0, "white", image2=None)
|
||||
|
||||
assert len(result) == 1
|
||||
assert len(result.result) == 1
|
||||
assert torch.equal(result[0], image1)
|
||||
|
||||
def test_basic_horizontal_stitch_right(self):
|
||||
|
||||
@ -99,6 +99,37 @@ class ComfyClient:
|
||||
with urllib.request.urlopen(url) as response:
|
||||
return json.loads(response.read())
|
||||
|
||||
def get_jobs(self, status=None, limit=None, offset=None, sort_by=None, sort_order=None):
|
||||
url = "http://{}/api/jobs".format(self.server_address)
|
||||
params = {}
|
||||
if status is not None:
|
||||
params["status"] = status
|
||||
if limit is not None:
|
||||
params["limit"] = limit
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
if sort_by is not None:
|
||||
params["sort_by"] = sort_by
|
||||
if sort_order is not None:
|
||||
params["sort_order"] = sort_order
|
||||
|
||||
if params:
|
||||
url_values = urllib.parse.urlencode(params)
|
||||
url = "{}?{}".format(url, url_values)
|
||||
|
||||
with urllib.request.urlopen(url) as response:
|
||||
return json.loads(response.read())
|
||||
|
||||
def get_job(self, job_id):
|
||||
url = "http://{}/api/jobs/{}".format(self.server_address, job_id)
|
||||
try:
|
||||
with urllib.request.urlopen(url) as response:
|
||||
return json.loads(response.read())
|
||||
except urllib.error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
return None
|
||||
raise
|
||||
|
||||
def set_test_name(self, name):
|
||||
self.test_name = name
|
||||
|
||||
@ -877,3 +908,106 @@ class TestExecution:
|
||||
result = client.get_all_history(max_items=5, offset=len(all_history) - 1)
|
||||
|
||||
assert len(result) <= 1, "Should return at most 1 item when offset is near end"
|
||||
|
||||
# Jobs API tests
|
||||
def test_jobs_api_job_structure(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test that job objects have required fields"""
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
jobs_response = client.get_jobs(status="completed", limit=1)
|
||||
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
|
||||
|
||||
job = jobs_response["jobs"][0]
|
||||
assert "id" in job, "Job should have id"
|
||||
assert "status" in job, "Job should have status"
|
||||
assert "create_time" in job, "Job should have create_time"
|
||||
assert "outputs_count" in job, "Job should have outputs_count"
|
||||
assert "preview_output" in job, "Job should have preview_output"
|
||||
|
||||
def test_jobs_api_preview_output_structure(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test that preview_output has correct structure"""
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
jobs_response = client.get_jobs(status="completed", limit=1)
|
||||
job = jobs_response["jobs"][0]
|
||||
|
||||
if job["preview_output"] is not None:
|
||||
preview = job["preview_output"]
|
||||
assert "filename" in preview, "Preview should have filename"
|
||||
assert "nodeId" in preview, "Preview should have nodeId"
|
||||
assert "mediaType" in preview, "Preview should have mediaType"
|
||||
|
||||
def test_jobs_api_pagination(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test jobs API pagination"""
|
||||
for _ in range(5):
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
first_page = client.get_jobs(limit=2, offset=0)
|
||||
second_page = client.get_jobs(limit=2, offset=2)
|
||||
|
||||
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
|
||||
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
|
||||
|
||||
first_ids = {j["id"] for j in first_page["jobs"]}
|
||||
second_ids = {j["id"] for j in second_page["jobs"]}
|
||||
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
|
||||
|
||||
def test_jobs_api_sorting(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test jobs API sorting"""
|
||||
for _ in range(3):
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
desc_jobs = client.get_jobs(sort_order="desc")
|
||||
asc_jobs = client.get_jobs(sort_order="asc")
|
||||
|
||||
if len(desc_jobs["jobs"]) >= 2:
|
||||
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
|
||||
asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]]
|
||||
if len(desc_times) >= 2:
|
||||
assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first"
|
||||
if len(asc_times) >= 2:
|
||||
assert asc_times == sorted(asc_times), "Asc should be oldest first"
|
||||
|
||||
def test_jobs_api_status_filter(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test jobs API status filtering"""
|
||||
self._create_history_item(client, builder)
|
||||
|
||||
completed_jobs = client.get_jobs(status="completed")
|
||||
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
|
||||
|
||||
for job in completed_jobs["jobs"]:
|
||||
assert job["status"] == "completed", "Should only return completed jobs"
|
||||
|
||||
# Pending jobs are transient - just verify filter doesn't error
|
||||
pending_jobs = client.get_jobs(status="pending")
|
||||
for job in pending_jobs["jobs"]:
|
||||
assert job["status"] == "pending", "Should only return pending jobs"
|
||||
|
||||
def test_get_job_by_id(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test getting a single job by ID"""
|
||||
result = self._create_history_item(client, builder)
|
||||
prompt_id = result.get_prompt_id()
|
||||
|
||||
job = client.get_job(prompt_id)
|
||||
assert job is not None, "Should find the job"
|
||||
assert job["id"] == prompt_id, "Job ID should match"
|
||||
assert "outputs" in job, "Single job should include outputs"
|
||||
|
||||
def test_get_job_not_found(
|
||||
self, client: ComfyClient, builder: GraphBuilder
|
||||
):
|
||||
"""Test getting a non-existent job returns 404"""
|
||||
job = client.get_job("nonexistent-job-id")
|
||||
assert job is None, "Non-existent job should return None"
|
||||
|
||||
361
tests/execution/test_jobs.py
Normal file
361
tests/execution/test_jobs.py
Normal file
@ -0,0 +1,361 @@
|
||||
"""Unit tests for comfy_execution/jobs.py"""
|
||||
|
||||
from comfy_execution.jobs import (
|
||||
JobStatus,
|
||||
is_previewable,
|
||||
normalize_queue_item,
|
||||
normalize_history_item,
|
||||
get_outputs_summary,
|
||||
apply_sorting,
|
||||
)
|
||||
|
||||
|
||||
class TestJobStatus:
|
||||
"""Test JobStatus constants."""
|
||||
|
||||
def test_status_values(self):
|
||||
"""Status constants should have expected string values."""
|
||||
assert JobStatus.PENDING == 'pending'
|
||||
assert JobStatus.IN_PROGRESS == 'in_progress'
|
||||
assert JobStatus.COMPLETED == 'completed'
|
||||
assert JobStatus.FAILED == 'failed'
|
||||
|
||||
def test_all_contains_all_statuses(self):
|
||||
"""ALL should contain all status values."""
|
||||
assert JobStatus.PENDING in JobStatus.ALL
|
||||
assert JobStatus.IN_PROGRESS in JobStatus.ALL
|
||||
assert JobStatus.COMPLETED in JobStatus.ALL
|
||||
assert JobStatus.FAILED in JobStatus.ALL
|
||||
assert len(JobStatus.ALL) == 4
|
||||
|
||||
|
||||
class TestIsPreviewable:
|
||||
"""Unit tests for is_previewable()"""
|
||||
|
||||
def test_previewable_media_types(self):
|
||||
"""Images, video, audio media types should be previewable."""
|
||||
for media_type in ['images', 'video', 'audio']:
|
||||
assert is_previewable(media_type, {}) is True
|
||||
|
||||
def test_non_previewable_media_types(self):
|
||||
"""Other media types should not be previewable."""
|
||||
for media_type in ['latents', 'text', 'metadata', 'files']:
|
||||
assert is_previewable(media_type, {}) is False
|
||||
|
||||
def test_3d_extensions_previewable(self):
|
||||
"""3D file extensions should be previewable regardless of media_type."""
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
|
||||
item = {'filename': f'model{ext}'}
|
||||
assert is_previewable('files', item) is True
|
||||
|
||||
def test_3d_extensions_case_insensitive(self):
|
||||
"""3D extension check should be case insensitive."""
|
||||
item = {'filename': 'MODEL.GLB'}
|
||||
assert is_previewable('files', item) is True
|
||||
|
||||
def test_video_format_previewable(self):
|
||||
"""Items with video/ format should be previewable."""
|
||||
item = {'format': 'video/mp4'}
|
||||
assert is_previewable('files', item) is True
|
||||
|
||||
def test_audio_format_previewable(self):
|
||||
"""Items with audio/ format should be previewable."""
|
||||
item = {'format': 'audio/wav'}
|
||||
assert is_previewable('files', item) is True
|
||||
|
||||
def test_other_format_not_previewable(self):
|
||||
"""Items with other format should not be previewable."""
|
||||
item = {'format': 'application/json'}
|
||||
assert is_previewable('files', item) is False
|
||||
|
||||
|
||||
class TestGetOutputsSummary:
|
||||
"""Unit tests for get_outputs_summary()"""
|
||||
|
||||
def test_empty_outputs(self):
|
||||
"""Empty outputs should return 0 count and None preview."""
|
||||
count, preview = get_outputs_summary({})
|
||||
assert count == 0
|
||||
assert preview is None
|
||||
|
||||
def test_counts_across_multiple_nodes(self):
|
||||
"""Outputs from multiple nodes should all be counted."""
|
||||
outputs = {
|
||||
'node1': {'images': [{'filename': 'a.png', 'type': 'output'}]},
|
||||
'node2': {'images': [{'filename': 'b.png', 'type': 'output'}]},
|
||||
'node3': {'images': [
|
||||
{'filename': 'c.png', 'type': 'output'},
|
||||
{'filename': 'd.png', 'type': 'output'}
|
||||
]}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert count == 4
|
||||
|
||||
def test_skips_animated_key_and_non_list_values(self):
|
||||
"""The 'animated' key and non-list values should be skipped."""
|
||||
outputs = {
|
||||
'node1': {
|
||||
'images': [{'filename': 'test.png', 'type': 'output'}],
|
||||
'animated': [True], # Should skip due to key name
|
||||
'metadata': 'string', # Should skip due to non-list
|
||||
'count': 42 # Should skip due to non-list
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert count == 1
|
||||
|
||||
def test_preview_prefers_type_output(self):
|
||||
"""Items with type='output' should be preferred for preview."""
|
||||
outputs = {
|
||||
'node1': {
|
||||
'images': [
|
||||
{'filename': 'temp.png', 'type': 'temp'},
|
||||
{'filename': 'output.png', 'type': 'output'}
|
||||
]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert count == 2
|
||||
assert preview['filename'] == 'output.png'
|
||||
|
||||
def test_preview_fallback_when_no_output_type(self):
|
||||
"""If no type='output', should use first previewable."""
|
||||
outputs = {
|
||||
'node1': {
|
||||
'images': [
|
||||
{'filename': 'temp1.png', 'type': 'temp'},
|
||||
{'filename': 'temp2.png', 'type': 'temp'}
|
||||
]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert preview['filename'] == 'temp1.png'
|
||||
|
||||
def test_non_previewable_media_types_counted_but_no_preview(self):
|
||||
"""Non-previewable media types should be counted but not used as preview."""
|
||||
outputs = {
|
||||
'node1': {
|
||||
'latents': [
|
||||
{'filename': 'latent1.safetensors'},
|
||||
{'filename': 'latent2.safetensors'}
|
||||
]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert count == 2
|
||||
assert preview is None
|
||||
|
||||
def test_previewable_media_types(self):
|
||||
"""Images, video, and audio media types should be previewable."""
|
||||
for media_type in ['images', 'video', 'audio']:
|
||||
outputs = {
|
||||
'node1': {
|
||||
media_type: [{'filename': 'test.file', 'type': 'output'}]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert preview is not None, f"{media_type} should be previewable"
|
||||
|
||||
def test_3d_files_previewable(self):
|
||||
"""3D file extensions should be previewable."""
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
|
||||
outputs = {
|
||||
'node1': {
|
||||
'files': [{'filename': f'model{ext}', 'type': 'output'}]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert preview is not None, f"3D file {ext} should be previewable"
|
||||
|
||||
def test_format_mime_type_previewable(self):
|
||||
"""Files with video/ or audio/ format should be previewable."""
|
||||
for fmt in ['video/x-custom', 'audio/x-custom']:
|
||||
outputs = {
|
||||
'node1': {
|
||||
'files': [{'filename': 'file.custom', 'format': fmt, 'type': 'output'}]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert preview is not None, f"Format {fmt} should be previewable"
|
||||
|
||||
def test_preview_enriched_with_node_metadata(self):
|
||||
"""Preview should include nodeId, mediaType, and original fields."""
|
||||
outputs = {
|
||||
'node123': {
|
||||
'images': [{'filename': 'test.png', 'type': 'output', 'subfolder': 'outputs'}]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert preview['nodeId'] == 'node123'
|
||||
assert preview['mediaType'] == 'images'
|
||||
assert preview['subfolder'] == 'outputs'
|
||||
|
||||
|
||||
class TestApplySorting:
|
||||
"""Unit tests for apply_sorting()"""
|
||||
|
||||
def test_sort_by_create_time_desc(self):
|
||||
"""Default sort by create_time descending."""
|
||||
jobs = [
|
||||
{'id': 'a', 'create_time': 100},
|
||||
{'id': 'b', 'create_time': 300},
|
||||
{'id': 'c', 'create_time': 200},
|
||||
]
|
||||
result = apply_sorting(jobs, 'created_at', 'desc')
|
||||
assert [j['id'] for j in result] == ['b', 'c', 'a']
|
||||
|
||||
def test_sort_by_create_time_asc(self):
|
||||
"""Sort by create_time ascending."""
|
||||
jobs = [
|
||||
{'id': 'a', 'create_time': 100},
|
||||
{'id': 'b', 'create_time': 300},
|
||||
{'id': 'c', 'create_time': 200},
|
||||
]
|
||||
result = apply_sorting(jobs, 'created_at', 'asc')
|
||||
assert [j['id'] for j in result] == ['a', 'c', 'b']
|
||||
|
||||
def test_sort_by_execution_duration(self):
|
||||
"""Sort by execution_duration should order by duration."""
|
||||
jobs = [
|
||||
{'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100}, # 5s
|
||||
{'id': 'b', 'create_time': 300, 'execution_start_time': 300, 'execution_end_time': 1300}, # 1s
|
||||
{'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200}, # 3s
|
||||
]
|
||||
result = apply_sorting(jobs, 'execution_duration', 'desc')
|
||||
assert [j['id'] for j in result] == ['a', 'c', 'b']
|
||||
|
||||
def test_sort_with_none_values(self):
|
||||
"""Jobs with None values should sort as 0."""
|
||||
jobs = [
|
||||
{'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100},
|
||||
{'id': 'b', 'create_time': 300, 'execution_start_time': None, 'execution_end_time': None},
|
||||
{'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200},
|
||||
]
|
||||
result = apply_sorting(jobs, 'execution_duration', 'asc')
|
||||
assert result[0]['id'] == 'b' # None treated as 0, comes first
|
||||
|
||||
|
||||
class TestNormalizeQueueItem:
|
||||
"""Unit tests for normalize_queue_item()"""
|
||||
|
||||
def test_basic_normalization(self):
|
||||
"""Queue item should be normalized to job dict."""
|
||||
item = (
|
||||
10, # priority/number
|
||||
'prompt-123', # prompt_id
|
||||
{'nodes': {}}, # prompt
|
||||
{
|
||||
'create_time': 1234567890,
|
||||
'extra_pnginfo': {'workflow': {'id': 'workflow-abc'}}
|
||||
}, # extra_data
|
||||
['node1'], # outputs_to_execute
|
||||
)
|
||||
job = normalize_queue_item(item, JobStatus.PENDING)
|
||||
|
||||
assert job['id'] == 'prompt-123'
|
||||
assert job['status'] == 'pending'
|
||||
assert job['priority'] == 10
|
||||
assert job['create_time'] == 1234567890
|
||||
assert 'execution_start_time' not in job
|
||||
assert 'execution_end_time' not in job
|
||||
assert 'execution_error' not in job
|
||||
assert 'preview_output' not in job
|
||||
assert job['outputs_count'] == 0
|
||||
assert job['workflow_id'] == 'workflow-abc'
|
||||
|
||||
|
||||
class TestNormalizeHistoryItem:
|
||||
"""Unit tests for normalize_history_item()"""
|
||||
|
||||
def test_completed_job(self):
|
||||
"""Completed history item should have correct status and times from messages."""
|
||||
history_item = {
|
||||
'prompt': (
|
||||
5, # priority
|
||||
'prompt-456',
|
||||
{'nodes': {}},
|
||||
{
|
||||
'create_time': 1234567890000,
|
||||
'extra_pnginfo': {'workflow': {'id': 'workflow-xyz'}}
|
||||
},
|
||||
['node1'],
|
||||
),
|
||||
'status': {
|
||||
'status_str': 'success',
|
||||
'completed': True,
|
||||
'messages': [
|
||||
('execution_start', {'prompt_id': 'prompt-456', 'timestamp': 1234567890500}),
|
||||
('execution_success', {'prompt_id': 'prompt-456', 'timestamp': 1234567893000}),
|
||||
]
|
||||
},
|
||||
'outputs': {},
|
||||
}
|
||||
job = normalize_history_item('prompt-456', history_item)
|
||||
|
||||
assert job['id'] == 'prompt-456'
|
||||
assert job['status'] == 'completed'
|
||||
assert job['priority'] == 5
|
||||
assert job['execution_start_time'] == 1234567890500
|
||||
assert job['execution_end_time'] == 1234567893000
|
||||
assert job['workflow_id'] == 'workflow-xyz'
|
||||
|
||||
def test_failed_job(self):
|
||||
"""Failed history item should have failed status and error from messages."""
|
||||
history_item = {
|
||||
'prompt': (
|
||||
5,
|
||||
'prompt-789',
|
||||
{'nodes': {}},
|
||||
{'create_time': 1234567890000},
|
||||
['node1'],
|
||||
),
|
||||
'status': {
|
||||
'status_str': 'error',
|
||||
'completed': False,
|
||||
'messages': [
|
||||
('execution_start', {'prompt_id': 'prompt-789', 'timestamp': 1234567890500}),
|
||||
('execution_error', {
|
||||
'prompt_id': 'prompt-789',
|
||||
'node_id': '5',
|
||||
'node_type': 'KSampler',
|
||||
'exception_message': 'CUDA out of memory',
|
||||
'exception_type': 'RuntimeError',
|
||||
'traceback': ['Traceback...', 'RuntimeError: CUDA out of memory'],
|
||||
'timestamp': 1234567891000,
|
||||
})
|
||||
]
|
||||
},
|
||||
'outputs': {},
|
||||
}
|
||||
|
||||
job = normalize_history_item('prompt-789', history_item)
|
||||
assert job['status'] == 'failed'
|
||||
assert job['execution_start_time'] == 1234567890500
|
||||
assert job['execution_end_time'] == 1234567891000
|
||||
assert job['execution_error']['node_id'] == '5'
|
||||
assert job['execution_error']['node_type'] == 'KSampler'
|
||||
assert job['execution_error']['exception_message'] == 'CUDA out of memory'
|
||||
|
||||
def test_include_outputs(self):
|
||||
"""When include_outputs=True, should include full output data."""
|
||||
history_item = {
|
||||
'prompt': (
|
||||
5,
|
||||
'prompt-123',
|
||||
{'nodes': {'1': {}}},
|
||||
{'create_time': 1234567890, 'client_id': 'abc'},
|
||||
['node1'],
|
||||
),
|
||||
'status': {'status_str': 'success', 'completed': True, 'messages': []},
|
||||
'outputs': {'node1': {'images': [{'filename': 'test.png'}]}},
|
||||
}
|
||||
job = normalize_history_item('prompt-123', history_item, include_outputs=True)
|
||||
|
||||
assert 'outputs' in job
|
||||
assert 'workflow' in job
|
||||
assert 'execution_status' in job
|
||||
assert job['outputs'] == {'node1': {'images': [{'filename': 'test.png'}]}}
|
||||
assert job['workflow'] == {
|
||||
'prompt': {'nodes': {'1': {}}},
|
||||
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user