ComfyUI (mirror of https://github.com/comfyanonymous/ComfyUI.git)

Compare commits
No commits in common. "master" and "v0.4.0" have entirely different histories.
@@ -53,16 +53,6 @@ try:
     repo.stash(ident)
 except KeyError:
     print("nothing to stash") # noqa: T201
-except:
-    print("Could not stash, cleaning index and trying again.") # noqa: T201
-    repo.state_cleanup()
-    repo.index.read_tree(repo.head.peel().tree)
-    repo.index.write()
-    try:
-        repo.stash(ident)
-    except KeyError:
-        print("nothing to stash.") # noqa: T201
-
 backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
 print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
 try:
.github/workflows/test-ci.yml (vendored): 1 line changed

@@ -5,7 +5,6 @@ on:
   push:
     branches:
       - master
-      - release/**
     paths-ignore:
       - 'app/**'
       - 'input/**'
.github/workflows/test-execution.yml (vendored): 4 lines changed

@@ -2,9 +2,9 @@ name: Execution Tests
 
 on:
   push:
-    branches: [ main, master, release/** ]
+    branches: [ main, master ]
   pull_request:
-    branches: [ main, master, release/** ]
+    branches: [ main, master ]
 
 jobs:
   test:
.github/workflows/test-launch.yml (vendored): 4 lines changed

@@ -2,9 +2,9 @@ name: Test server launches without errors
 
 on:
   push:
-    branches: [ main, master, release/** ]
+    branches: [ main, master ]
   pull_request:
-    branches: [ main, master, release/** ]
+    branches: [ main, master ]
 
 jobs:
   test:
.github/workflows/test-unit.yml (vendored): 4 lines changed

@@ -2,9 +2,9 @@ name: Unit Tests
 
 on:
   push:
-    branches: [ main, master, release/** ]
+    branches: [ main, master ]
   pull_request:
-    branches: [ main, master, release/** ]
+    branches: [ main, master ]
 
 jobs:
   test:
.github/workflows/update-version.yml (vendored): 1 line changed

@@ -6,7 +6,6 @@ on:
       - "pyproject.toml"
     branches:
       - master
-      - release/**
 
 jobs:
   update-version:
@@ -58,13 +58,8 @@ class InternalRoutes:
             return web.json_response({"error": "Invalid directory type"}, status=400)
 
         directory = get_directory_by_type(directory_type)
-
-        def is_visible_file(entry: os.DirEntry) -> bool:
-            """Filter out hidden files (e.g., .DS_Store on macOS)."""
-            return entry.is_file() and not entry.name.startswith('.')
-
         sorted_files = sorted(
-            (entry for entry in os.scandir(directory) if is_visible_file(entry)),
+            (entry for entry in os.scandir(directory) if entry.is_file()),
             key=lambda entry: -entry.stat().st_mtime
         )
         return web.json_response([entry.name for entry in sorted_files], status=200)
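For reference, a minimal standalone sketch of the hidden-file filtering that the master side adds in the hunk above. The wrapper function and its name are illustrative only; the filter logic mirrors the diff.

import os

def list_visible_files(directory):
    """Return regular file names sorted newest-first, skipping dotfiles such as .DS_Store."""
    def is_visible_file(entry: os.DirEntry) -> bool:
        # Hidden files on macOS/Linux start with a dot.
        return entry.is_file() and not entry.name.startswith('.')

    entries = sorted(
        (e for e in os.scandir(directory) if is_visible_file(e)),
        key=lambda e: -e.stat().st_mtime,  # most recently modified first
    )
    return [e.name for e in entries]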
@@ -97,13 +97,6 @@ class LatentPreviewMethod(enum.Enum):
     Latent2RGB = "latent2rgb"
     TAESD = "taesd"
 
-    @classmethod
-    def from_string(cls, value: str):
-        for member in cls:
-            if member.value == value:
-                return member
-        return None
-
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
 
 parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
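The `from_string` helper removed on the v0.4.0 side is a plain value-to-member lookup. A self-contained sketch of the same pattern follows; member values other than Latent2RGB and TAESD are placeholders, not taken from the source.

import enum

class LatentPreviewMethod(enum.Enum):
    NoPreviews = "none"        # placeholder value
    Latent2RGB = "latent2rgb"
    TAESD = "taesd"

    @classmethod
    def from_string(cls, value: str):
        # Map a CLI string like "taesd" back to its enum member, or None if unknown.
        for member in cls:
            if member.value == value:
                return member
        return None

assert LatentPreviewMethod.from_string("taesd") is LatentPreviewMethod.TAESD
assert LatentPreviewMethod.from_string("bogus") is None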
@@ -87,7 +87,6 @@ class IndexListCallbacks:
     COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
     EXECUTE_START = "execute_start"
     EXECUTE_CLEANUP = "execute_cleanup"
-    RESIZE_COND_ITEM = "resize_cond_item"
 
     def init_callbacks(self):
         return {}

@@ -167,18 +166,6 @@ class IndexListContextHandler(ContextHandlerABC):
             new_cond_item = cond_item.copy()
             # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
             for cond_key, cond_value in new_cond_item.items():
-                # Allow callbacks to handle custom conditioning items
-                handled = False
-                for callback in comfy.patcher_extension.get_all_callbacks(
-                    IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
-                ):
-                    result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
-                    if result is not None:
-                        new_cond_item[cond_key] = result
-                        handled = True
-                        break
-                if handled:
-                    continue
                 if isinstance(cond_value, torch.Tensor):
                     if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
                             (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
@@ -1557,13 +1557,10 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
 
 
 @torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
     """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
     arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
     """
-    if solver_type not in {"phi_1", "phi_2"}:
-        raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
-
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler

@@ -1603,14 +1600,8 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
             denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
 
             # Step 2
-            if solver_type == "phi_1":
-                denoised_d = torch.lerp(denoised, denoised_2, fac)
-                x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
-            elif solver_type == "phi_2":
-                b2 = ei_h_phi_2(-h_eta) / r
-                b1 = ei_h_phi_1(-h_eta) - b2
-                x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
-
+            denoised_d = torch.lerp(denoised, denoised_2, fac)
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
             if inject_noise:
                 segment_factor = (r - 1) * h * eta
                 sde_noise = sde_noise * segment_factor.exp()

@@ -1618,17 +1609,6 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
                 x = x + sde_noise * sigmas[i + 1] * s_noise
     return x
 
 
-@torch.no_grad()
-def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
-    """Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
-    return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
-
-
-@torch.no_grad()
-def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
-    """Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
-    return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
-
 @torch.no_grad()
 def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):

@@ -1776,7 +1756,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
         # Predictor
         if sigmas[i + 1] == 0:
             # Denoising step
-            x_pred = denoised
+            x = denoised
         else:
             tau_t = tau_func(sigmas[i + 1])
             curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]

@@ -1797,7 +1777,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
             if tau_t > 0 and s_noise > 0:
                 noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
                 x_pred = x_pred + noise
-    return x_pred
+    return x
 
 
 @torch.no_grad()
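The `phi_2` branch removed above uses the standard exponential-integrator weights phi_1(z) = (e^z - 1)/z and phi_2(z) = (e^z - z - 1)/z^2 to combine the two model evaluations. A hedged sketch of that second-stage update follows; the file's `ei_h_phi_*` helpers are assumed to compute these weights, and the function signature here is illustrative rather than the actual sampler code.

import torch

def phi_1(z):
    # (exp(z) - 1) / z, first exponential-integrator weight.
    return torch.expm1(z) / z

def phi_2(z):
    # (exp(z) - z - 1) / z**2, second-order weight.
    return (torch.expm1(z) - z) / z ** 2

def seeds2_second_stage(x, denoised, denoised_2, sigma, sigma_next, alpha_t, h, h_eta, eta, r, solver_type="phi_1", fac=0.5):
    """Sketch of the step-2 update as it appears on the master side of the diff."""
    if solver_type == "phi_1":
        denoised_d = torch.lerp(denoised, denoised_2, fac)
        return sigma_next / sigma * (-h * eta).exp() * x - alpha_t * phi_1(-h_eta) * denoised_d
    # solver_type == "phi_2": weight the two denoised estimates separately.
    b2 = phi_2(-h_eta) / r
    b1 = phi_1(-h_eta) - b2
    return sigma_next / sigma * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)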
@@ -41,11 +41,6 @@ class ZImage_Control(torch.nn.Module):
         ffn_dim_multiplier: float = (8.0 / 3.0),
         norm_eps: float = 1e-5,
         qk_norm: bool = True,
-        n_control_layers=6,
-        control_in_dim=16,
-        additional_in_dim=0,
-        broken=False,
-        refiner_control=False,
         dtype=None,
         device=None,
         operations=None,

@@ -54,11 +49,10 @@ class ZImage_Control(torch.nn.Module):
         super().__init__()
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}
 
-        self.broken = broken
-        self.additional_in_dim = additional_in_dim
-        self.control_in_dim = control_in_dim
+        self.additional_in_dim = 0
+        self.control_in_dim = 16
         n_refiner_layers = 2
-        self.n_control_layers = n_control_layers
+        self.n_control_layers = 6
         self.control_layers = nn.ModuleList(
             [
                 ZImageControlTransformerBlock(

@@ -80,49 +74,28 @@ class ZImage_Control(torch.nn.Module):
         all_x_embedder = {}
         patch_size = 2
         f_patch_size = 1
-        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
+        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
         all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
 
-        self.refiner_control = refiner_control
-
         self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
-        if self.refiner_control:
-            self.control_noise_refiner = nn.ModuleList(
-                [
-                    ZImageControlTransformerBlock(
-                        layer_id,
-                        dim,
-                        n_heads,
-                        n_kv_heads,
-                        multiple_of,
-                        ffn_dim_multiplier,
-                        norm_eps,
-                        qk_norm,
-                        block_id=layer_id,
-                        operation_settings=operation_settings,
-                    )
-                    for layer_id in range(n_refiner_layers)
-                ]
-            )
-        else:
-            self.control_noise_refiner = nn.ModuleList(
-                [
-                    JointTransformerBlock(
-                        layer_id,
-                        dim,
-                        n_heads,
-                        n_kv_heads,
-                        multiple_of,
-                        ffn_dim_multiplier,
-                        norm_eps,
-                        qk_norm,
-                        modulation=True,
-                        z_image_modulation=True,
-                        operation_settings=operation_settings,
-                    )
-                    for layer_id in range(n_refiner_layers)
-                ]
-            )
+        self.control_noise_refiner = nn.ModuleList(
+            [
+                JointTransformerBlock(
+                    layer_id,
+                    dim,
+                    n_heads,
+                    n_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    qk_norm,
+                    modulation=True,
+                    z_image_modulation=True,
+                    operation_settings=operation_settings,
+                )
+                for layer_id in range(n_refiner_layers)
+            ]
+        )
 
     def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
         patch_size = 2

@@ -132,29 +105,9 @@ class ZImage_Control(torch.nn.Module):
         control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
 
         x_attn_mask = None
-        if not self.refiner_control:
-            for layer in self.control_noise_refiner:
-                control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
-
+        for layer in self.control_noise_refiner:
+            control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
         return control_context
 
-    def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
-        if self.refiner_control:
-            if self.broken:
-                if layer_id == 0:
-                    return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
-                if layer_id > 0:
-                    out = None
-                    for i in range(1, len(self.control_layers)):
-                        o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
-                        if out is None:
-                            out = o
-
-                    return (out, control_context)
-            else:
-                return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
-        else:
-            return (None, control_context)
-
     def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
         return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
@@ -536,7 +536,6 @@ class NextDiT(nn.Module):
         bsz = len(x)
         pH = pW = self.patch_size
         device = x[0].device
-        orig_x = x
 
         if self.pad_tokens_multiple is not None:
             pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple

@@ -573,21 +572,13 @@ class NextDiT(nn.Module):
 
         freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
 
-        patches = transformer_options.get("patches", {})
-
         # refine context
         for layer in self.context_refiner:
             cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
 
         padded_img_mask = None
-        x_input = x
-        for i, layer in enumerate(self.noise_refiner):
+        for layer in self.noise_refiner:
             x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
-            if "noise_refiner" in patches:
-                for p in patches["noise_refiner"]:
-                    out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
-                    if "img" in out:
-                        x = out["img"]
 
         padded_full_embed = torch.cat((cap_feats, x), dim=1)
         mask = None

@@ -631,18 +622,14 @@ class NextDiT(nn.Module):
 
         patches = transformer_options.get("patches", {})
         x_is_tensor = isinstance(x, torch.Tensor)
-        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
+        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
         freqs_cis = freqs_cis.to(img.device)
 
-        transformer_options["total_blocks"] = len(self.layers)
-        transformer_options["block_type"] = "double"
-        img_input = img
         for i, layer in enumerate(self.layers):
-            transformer_options["block_index"] = i
             img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
             if "double_block" in patches:
                 for p in patches["double_block"]:
-                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+                    out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
                     if "img" in out:
                         img[:, cap_size[0]:] = out["img"]
                     if "txt" in out:
@@ -218,24 +218,9 @@ class QwenImageTransformerBlock(nn.Module):
             operations=operations,
         )
 
-    def _apply_gate(self, x, y, gate, timestep_zero_index=None):
-        if timestep_zero_index is not None:
-            return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
-        else:
-            return torch.addcmul(y, gate, x)
-
-    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
-        if timestep_zero_index is not None:
-            actual_batch = shift.size(0) // 2
-            shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
-            scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
-            gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
-            reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
-            zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
-            return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
-        else:
-            return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
+        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
 
     def forward(
         self,

@@ -244,19 +229,14 @@ class QwenImageTransformerBlock(nn.Module):
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        timestep_zero_index=None,
         transformer_options={},
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_params = self.img_mod(temb)
-
-        if timestep_zero_index is not None:
-            temb = temb.chunk(2, dim=0)[0]
-
         txt_mod_params = self.txt_mod(temb)
         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
 
-        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
+        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
         del img_mod1
         txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
         del txt_mod1

@@ -271,15 +251,15 @@ class QwenImageTransformerBlock(nn.Module):
         del img_modulated
         del txt_modulated
 
-        hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
+        hidden_states = hidden_states + img_gate1 * img_attn_output
         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
         del img_attn_output
         del txt_attn_output
         del img_gate1
         del txt_gate1
 
-        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
-        hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
+        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
+        hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
 
         txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
         encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))

@@ -322,7 +302,6 @@ class QwenImageTransformer2DModel(nn.Module):
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
-        default_ref_method="index",
         image_model=None,
         final_layer=True,
         dtype=None,

@@ -335,7 +314,6 @@ class QwenImageTransformer2DModel(nn.Module):
         self.in_channels = in_channels
         self.out_channels = out_channels or in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
-        self.default_ref_method = default_ref_method
 
         self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
 

@@ -363,9 +341,6 @@ class QwenImageTransformer2DModel(nn.Module):
             for _ in range(num_layers)
         ])
 
-        if self.default_ref_method == "index_timestep_zero":
-            self.register_buffer("__index_timestep_zero__", torch.tensor([]))
-
         if final_layer:
             self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
             self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)

@@ -416,14 +391,11 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]
 
-        timestep_zero_index = None
         if ref_latents is not None:
             h = 0
             w = 0
             index = 0
-            ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
-            index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
-            timestep_zero = ref_method == "index_timestep_zero"
+            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
             for ref in ref_latents:
                 if index_ref_method:
                     index += 1

@@ -443,10 +415,6 @@ class QwenImageTransformer2DModel(nn.Module):
                 kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                 hidden_states = torch.cat([hidden_states, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
-            if timestep_zero:
-                if index > 0:
-                    timestep = torch.cat([timestep, timestep * 0], dim=0)
-                    timestep_zero_index = num_embeds
 
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)

@@ -478,7 +446,7 @@ class QwenImageTransformer2DModel(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
+                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
                     return out
                 out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 hidden_states = out["img"]

@@ -490,7 +458,6 @@ class QwenImageTransformer2DModel(nn.Module):
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
-                    timestep_zero_index=timestep_zero_index,
                     transformer_options=transformer_options,
                 )
 

@@ -507,9 +474,6 @@ class QwenImageTransformer2DModel(nn.Module):
         if add is not None:
             hidden_states[:, :add.shape[1]] += add
 
-        if timestep_zero_index is not None:
-            temb = temb.chunk(2, dim=0)[0]
-
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
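The `_modulate` helper kept on both sides applies AdaLN-style modulation, x * (1 + scale) + shift, and returns the gate used for the residual update. A minimal sketch of the shared math; the tensor shapes are assumptions for illustration.

import torch

def modulate(x, mod_params):
    """x: (batch, seq, dim); mod_params: (batch, 3 * dim) holding shift, scale, gate."""
    shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
    # torch.addcmul(a, b, c) computes a + b * c, so this is shift + x * (1 + scale).
    modulated = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
    return modulated, gate.unsqueeze(1)

x = torch.randn(2, 16, 64)
mod = torch.randn(2, 3 * 64)
out, gate = modulate(x, mod)
# Gated residual update, as on the v0.4.0 side: y = y + gate * sublayer(out)
y = x + gate * out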
@@ -568,10 +568,7 @@ class WanModel(torch.nn.Module):
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
-        transformer_options["total_blocks"] = len(self.blocks)
-        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
-            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

@@ -766,10 +763,7 @@ class VaceWanModel(WanModel):
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
-        transformer_options["total_blocks"] = len(self.blocks)
-        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
-            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

@@ -868,10 +862,7 @@ class CameraWanModel(WanModel):
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
-        transformer_options["total_blocks"] = len(self.blocks)
-        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
-            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

@@ -1335,19 +1326,16 @@ class WanModel_S2V(WanModel):
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
-        transformer_options["total_blocks"] = len(self.blocks)
-        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
-            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
                     return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
+                x = block(x, e=e0, freqs=freqs, context=context)
             if audio_emb is not None:
                 x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
         # head

@@ -1586,10 +1574,7 @@ class HumoWanModel(WanModel):
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
-        transformer_options["total_blocks"] = len(self.blocks)
-        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
-            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

@@ -523,10 +523,7 @@ class AnimateWanModel(WanModel):
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
-        transformer_options["total_blocks"] = len(self.blocks)
-        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
-            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -259,10 +259,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["nerf_tile_size"] = 512
         dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
         dit_config["nerf_embedder_dtype"] = torch.float32
-        if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
+        if "__x0__" in state_dict_keys: # x0 pred
             dit_config["use_x0"] = True
-        else:
-            dit_config["use_x0"] = False
     else:
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

@@ -618,8 +616,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["image_model"] = "qwen_image"
         dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
-        if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
-            dit_config["default_ref_method"] = "index_timestep_zero"
         return dit_config
 
     if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
@@ -454,9 +454,6 @@ class ModelPatcher:
     def set_model_post_input_patch(self, patch):
         self.set_model_patch(patch, "post_input")
 
-    def set_model_noise_refiner_patch(self, patch):
-        self.set_model_patch(patch, "noise_refiner")
-
     def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
         rope_options = self.model_options["transformer_options"].get("rope_options", {})
         rope_options["scale_x"] = scale_x
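The removed method is a one-line convenience wrapper around the generic patch registry; patches registered under a name are later looked up by the model's forward pass (see the "noise_refiner" lookup in the NextDiT hunk above). A hedged sketch of the idea using a stand-in class, not the real ModelPatcher; the storage layout and callback contract shown here are assumptions.

class TinyPatcher:
    """Illustrative stand-in for a patch registry keyed by patch name."""
    def __init__(self):
        self.model_options = {"transformer_options": {"patches": {}}}

    def set_model_patch(self, patch, name):
        patches = self.model_options["transformer_options"]["patches"]
        patches.setdefault(name, []).append(patch)

    def set_model_noise_refiner_patch(self, patch):
        # The master-side convenience wrapper: just a named bucket.
        self.set_model_patch(patch, "noise_refiner")

def my_noise_refiner_patch(args):
    # Receives a dict of tensors for the current block; returns any keys to override.
    return {"img": args["img"]}  # no-op here

patcher = TinyPatcher()
patcher.set_model_noise_refiner_patch(my_noise_refiner_patch)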
comfy/ops.py: 27 lines changed
@@ -497,14 +497,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
     ) -> None:
         super().__init__()
 
-        if dtype is None:
-            dtype = MixedPrecisionOps._compute_dtype
-
-        self.factory_kwargs = {"device": device, "dtype": dtype}
+        self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+        # self.factory_kwargs = {"device": device, "dtype": dtype}
 
         self.in_features = in_features
         self.out_features = out_features
-        self._has_bias = bias
+        if bias:
+            self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+        else:
+            self.register_parameter("bias", None)
 
         self.tensor_class = None
         self._full_precision_mm = MixedPrecisionOps._full_precision_mm

@@ -529,14 +530,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             layer_conf = json.loads(layer_conf.numpy().tobytes())
 
         if layer_conf is None:
-            dtype = self.factory_kwargs["dtype"]
-            self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
-            if dtype != MixedPrecisionOps._compute_dtype:
-                self.comfy_cast_weights = True
-            if self._has_bias:
-                self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
-            else:
-                self.register_parameter("bias", None)
+            self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
         else:
             self.quant_format = layer_conf.get("format", None)
             if not self._full_precision_mm:

@@ -566,11 +560,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     requires_grad=False
                 )
 
-            if self._has_bias:
-                self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
-            else:
-                self.register_parameter("bias", None)
-
             for param_name in qconfig["parameters"]:
                 param_key = f"{prefix}{param_name}"
                 _v = state_dict.pop(param_key, None)

@@ -592,7 +581,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
         quant_conf = {"format": self.quant_format}
         if self._full_precision_mm:
             quant_conf["full_precision_matrix_mult"] = True
-        sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+        sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
         return sd
 
     def _forward(self, input, weight, bias):
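Both serialization styles in the last hunk store the per-layer quant config as JSON bytes inside a uint8 tensor: torch.frombuffer views the underlying buffer (and, given a plain bytes object, PyTorch warns that the buffer is not writable), while torch.tensor(list(...)) makes an explicit copy. A small sketch comparing the two; the config values are examples, not taken from the source.

import json
import torch

quant_conf = {"format": "float8_e4m3fn", "full_precision_matrix_mult": True}  # example values
payload = json.dumps(quant_conf).encode('utf-8')

copied = torch.tensor(list(payload), dtype=torch.uint8)            # master style: copies the bytes
viewed = torch.frombuffer(bytearray(payload), dtype=torch.uint8)   # v0.4.0 style; bytearray avoids the read-only warning

assert torch.equal(copied, viewed)
# Round-trip back to the dict either way:
assert json.loads(bytes(copied.tolist())) == quant_conf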
@@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
     minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
     return memory_required, minimum_memory_required
 
-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
         _prepare_sampling,
         comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
     )
-    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
+    return executor.execute(model, noise_shape, conds, model_options=model_options)
 
-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
     real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
     memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
+    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
     real_model = model.model
 
     return real_model, conds, models
@@ -720,7 +720,7 @@ class Sampler:
         sigma = float(sigmas[0])
         return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
 
-KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
+KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                   "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
@@ -549,10 +549,8 @@ class VAE:
             ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
             self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
-            self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
-            self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
 
 
         # Hunyuan 3d v2 2.0 & 2.1
         elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
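These lambdas are rough byte budgets: a constant times the latent's spatial size, with decode additionally scaled by the 8x8 output pixels per latent position and by bytes per element. A worked example under assumed inputs, using the master-side constants that depend on the number of latent frames.

# Assumed example latent shape (batch, channels, frames, height, width); values are illustrative.
shape = (1, 16, 1, 60, 104)   # roughly a 480x832 frame at 8x spatial compression
dtype_size = 2                # bfloat16 / float16 = 2 bytes per element

# master branch: cheaper constants when there are few latent frames (shape[2] <= 4)
encode_bytes = (1500 if shape[2] <= 4 else 6000) * shape[3] * shape[4] * dtype_size
decode_bytes = (2200 if shape[2] <= 4 else 7000) * shape[3] * shape[4] * (8 * 8) * dtype_size

print(encode_bytes / 1024**2, "MiB estimated to encode")   # about 17.9 MiB
print(decode_bytes / 1024**3, "GiB estimated to decode")   # about 1.6 GiB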
@ -28,7 +28,6 @@ from . import supported_models_base
|
|||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
|
|
||||||
from . import diffusers_convert
|
from . import diffusers_convert
|
||||||
import comfy.model_management
|
|
||||||
|
|
||||||
class SD15(supported_models_base.BASE):
|
class SD15(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
@ -542,7 +541,7 @@ class SD3(supported_models_base.BASE):
|
|||||||
unet_extra_config = {}
|
unet_extra_config = {}
|
||||||
latent_format = latent_formats.SD3
|
latent_format = latent_formats.SD3
|
||||||
|
|
||||||
memory_usage_factor = 1.6
|
memory_usage_factor = 1.2
|
||||||
|
|
||||||
text_encoder_key_prefix = ["text_encoders."]
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
@ -966,7 +965,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
|||||||
|
|
||||||
def __init__(self, unet_config):
|
def __init__(self, unet_config):
|
||||||
super().__init__(unet_config)
|
super().__init__(unet_config)
|
||||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
out = model_base.CosmosPredict2(self, device=device)
|
out = model_base.CosmosPredict2(self, device=device)
|
||||||
@ -1027,15 +1026,9 @@ class ZImage(Lumina2):
|
|||||||
"shift": 3.0,
|
"shift": 3.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
memory_usage_factor = 2.0
|
memory_usage_factor = 1.7
|
||||||
|
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
|
||||||
def __init__(self, unet_config):
|
|
||||||
super().__init__(unet_config)
|
|
||||||
if comfy.model_management.extended_fp16_support():
|
|
||||||
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
|
|
||||||
self.supported_inference_dtypes.insert(1, torch.float16)
|
|
||||||
|
|
||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
@ -1296,7 +1289,7 @@ class ChromaRadiance(Chroma):
|
|||||||
latent_format = comfy.latent_formats.ChromaRadiance
|
latent_format = comfy.latent_formats.ChromaRadiance
|
||||||
|
|
||||||
# Pixel-space model, no spatial compression for model input.
|
# Pixel-space model, no spatial compression for model input.
|
||||||
memory_usage_factor = 0.044
|
memory_usage_factor = 0.038
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
return model_base.ChromaRadiance(self, device=device)
|
return model_base.ChromaRadiance(self, device=device)
|
||||||
@ -1339,7 +1332,7 @@ class Omnigen2(supported_models_base.BASE):
|
|||||||
"shift": 2.6,
|
"shift": 2.6,
|
||||||
}
|
}
|
||||||
|
|
||||||
memory_usage_factor = 1.95 #TODO
|
memory_usage_factor = 1.65 #TODO
|
||||||
|
|
||||||
unet_extra_config = {}
|
unet_extra_config = {}
|
||||||
latent_format = latent_formats.Flux
|
latent_format = latent_formats.Flux
|
||||||
@ -1404,7 +1397,7 @@ class HunyuanImage21(HunyuanVideo):
|
|||||||
|
|
||||||
latent_format = latent_formats.HunyuanImage21
|
latent_format = latent_formats.HunyuanImage21
|
||||||
|
|
||||||
memory_usage_factor = 8.7
|
memory_usage_factor = 7.7
|
||||||
|
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
@ -1495,7 +1488,7 @@ class Kandinsky5(supported_models_base.BASE):
|
|||||||
unet_extra_config = {}
|
unet_extra_config = {}
|
||||||
latent_format = latent_formats.HunyuanVideo
|
latent_format = latent_formats.HunyuanVideo
|
||||||
|
|
||||||
memory_usage_factor = 1.25 #TODO
|
memory_usage_factor = 1.1 #TODO
|
||||||
|
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
@@ -1524,7 +1517,7 @@ class Kandinsky5Image(Kandinsky5):
     }

     latent_format = latent_formats.Flux
-    memory_usage_factor = 1.25 #TODO
+    memory_usage_factor = 1.1 #TODO

     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Kandinsky5Image(self, device=device)
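For context, the memory_usage_factor constants adjusted in the hunks above act as per-model multipliers on an estimated inference footprint. A rough, self-contained sketch of how such a factor might be applied; the base formula and constants below are assumptions for illustration only, not ComfyUI's actual estimator:

def estimate_inference_bytes(batch: int, height: int, width: int,
                             dtype_size: int, memory_usage_factor: float) -> int:
    # Hypothetical estimator: scale a naive activation-size guess by the model's factor.
    base = batch * height * width * dtype_size * 64  # 64 = assumed per-pixel channel budget
    return int(base * memory_usage_factor)

# Example: lowering the factor from 1.25 to 1.1 shrinks the reserved estimate by ~12%.
print(estimate_inference_bytes(1, 1024, 1024, 2, 1.1))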
@@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
     ALWAYS_SAFE_LOAD = True
     logging.info("Checkpoint files will always be loaded safely.")
 else:
-    logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
+    logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")

 def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
     if device is None:
@@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
     if quant_metadata is not None:
         layers = quant_metadata["layers"]
         for k, v in layers.items():
-            state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
+            state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)

     return state_dict, metadata
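Both sides of the convert_old_quants change store per-layer quantization metadata as JSON bytes inside a uint8 tensor; torch.frombuffer simply avoids the intermediate Python list that torch.tensor(list(...)) builds. A minimal standalone round-trip sketch (not ComfyUI code; the metadata dict is an invented example):

import json
import torch

layer_info = {"dtype": "float8_e4m3fn", "group_size": 64}  # example metadata only
buf = json.dumps(layer_info).encode("utf-8")

# frombuffer shares memory with the source buffer; a writable bytearray avoids warnings.
t_fast = torch.frombuffer(bytearray(buf), dtype=torch.uint8)
t_slow = torch.tensor(list(buf), dtype=torch.uint8)
assert torch.equal(t_fast, t_slow)

# Decoding reverses the encoding, recovering the original dict.
decoded = json.loads(bytes(t_fast.tolist()).decode("utf-8"))
assert decoded == layer_info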
@@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend,
 allowing graceful protocol evolution while maintaining backward compatibility.
 """

-from typing import Any
+from typing import Any, Dict

 from comfy.cli_args import args

 # Default server capabilities
-SERVER_FEATURE_FLAGS: dict[str, Any] = {
+SERVER_FEATURE_FLAGS: Dict[str, Any] = {
     "supports_preview_metadata": True,
     "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
     "extension": {"manager": {"supports_v4": True}},
@@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {


 def get_connection_feature(
-    sockets_metadata: dict[str, dict[str, Any]],
+    sockets_metadata: Dict[str, Dict[str, Any]],
     sid: str,
     feature_name: str,
     default: Any = False
@@ -42,7 +42,7 @@ def get_connection_feature(


 def supports_feature(
-    sockets_metadata: dict[str, dict[str, Any]],
+    sockets_metadata: Dict[str, Dict[str, Any]],
     sid: str,
     feature_name: str
 ) -> bool:
@@ -60,7 +60,7 @@ def supports_feature(
     return get_connection_feature(sockets_metadata, sid, feature_name, False) is True


-def get_server_features() -> dict[str, Any]:
+def get_server_features() -> Dict[str, Any]:
     """
     Get the server's feature flags.

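For context, a small sketch of how the per-connection feature lookup above could be exercised; the exact layout of sockets_metadata (sid -> {"feature_flags": {...}}) is an assumption based on the signatures shown, not a confirmed schema:

from typing import Any, Dict

sockets_metadata: Dict[str, Dict[str, Any]] = {
    "client-1": {"feature_flags": {"supports_preview_metadata": True}},
}

def supports_feature(metadata: Dict[str, Dict[str, Any]], sid: str, feature_name: str) -> bool:
    # Missing client or missing flag both fall back to False, matching the default above.
    flags = metadata.get(sid, {}).get("feature_flags", {})
    return flags.get(feature_name, False) is True

print(supports_feature(sockets_metadata, "client-1", "supports_preview_metadata"))  # True
print(supports_feature(sockets_metadata, "client-2", "supports_preview_metadata"))  # False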
@@ -1,4 +1,4 @@
-from typing import NamedTuple
+from typing import Type, List, NamedTuple
 from comfy_api.internal.singleton import ProxiedSingleton
 from packaging import version as packaging_version

@@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):

 class ComfyAPIWithVersion(NamedTuple):
     version: str
-    api_class: type[ComfyAPIBase]
+    api_class: Type[ComfyAPIBase]


 def parse_version(version_str: str) -> packaging_version.Version:
@@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
     return packaging_version.parse(version_str)


-registered_versions: list[ComfyAPIWithVersion] = []
+registered_versions: List[ComfyAPIWithVersion] = []


-def register_versions(versions: list[ComfyAPIWithVersion]):
+def register_versions(versions: List[ComfyAPIWithVersion]):
     versions.sort(key=lambda x: parse_version(x.version))
     global registered_versions
     registered_versions = versions


-def get_all_versions() -> list[ComfyAPIWithVersion]:
+def get_all_versions() -> List[ComfyAPIWithVersion]:
     """
     Returns a list of all registered ComfyAPI versions.
     """
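The registry above sorts version strings semantically via packaging before storing them. A small standalone illustration of why that sort key matters (this snippet only mirrors the sorting behaviour; it does not reproduce any special handling parse_version may do for values like "latest"):

from packaging import version as packaging_version

def parse_version(version_str: str):
    return packaging_version.parse(version_str)

versions = ["0.0.2", "0.0.10", "0.0.1"]
versions.sort(key=parse_version)
print(versions)  # ['0.0.1', '0.0.2', '0.0.10'] — semantic, not lexicographic, order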
@@ -8,7 +8,7 @@ import os
 import textwrap
 import threading
 from enum import Enum
-from typing import Optional, get_origin, get_args, get_type_hints
+from typing import Optional, Type, get_origin, get_args, get_type_hints


 class TypeTracker:
@@ -193,7 +193,7 @@ class AsyncToSyncConverter:
         return result_container["result"]

     @classmethod
-    def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
+    def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
         """
         Creates a new class with synchronous versions of all async methods.

@@ -563,7 +563,7 @@ class AsyncToSyncConverter:

     @classmethod
     def _generate_imports(
-        cls, async_class: type, type_tracker: TypeTracker
+        cls, async_class: Type, type_tracker: TypeTracker
     ) -> list[str]:
         """Generate import statements for the stub file."""
         imports = []
@@ -628,7 +628,7 @@ class AsyncToSyncConverter:
         return imports

     @classmethod
-    def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
+    def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
         """Extract class attributes that are classes themselves."""
         class_attributes = []

@@ -654,7 +654,7 @@ class AsyncToSyncConverter:
     def _generate_inner_class_stub(
         cls,
         name: str,
-        attr: type,
+        attr: Type,
         indent: str = "    ",
         type_tracker: Optional[TypeTracker] = None,
     ) -> list[str]:
@@ -782,7 +782,7 @@ class AsyncToSyncConverter:
         return processed

     @classmethod
-    def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
+    def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
         """
         Generate a .pyi stub file for the sync class to help IDEs with type checking.
         """
@@ -988,7 +988,7 @@ class AsyncToSyncConverter:
         logging.error(traceback.format_exc())


-def create_sync_class(async_class: type, thread_pool_size=10) -> type:
+def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
     """
     Creates a sync version of an async class

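The converter above wraps an async class so its coroutine methods can be called synchronously. A deliberately tiny stand-in (using asyncio.run per call rather than the thread-pool machinery of the real converter) to show the shape of the idea; names and behaviour here are a sketch, not the project's implementation:

import asyncio

class AsyncGreeter:
    async def greet(self, name: str) -> str:
        await asyncio.sleep(0)
        return f"hello {name}"

def create_sync_class_toy(async_class, thread_pool_size=10):
    class Sync:
        def __init__(self):
            self._inner = async_class()
        def __getattr__(self, name):
            attr = getattr(self._inner, name)
            if asyncio.iscoroutinefunction(attr):
                # Run the coroutine to completion on each call.
                return lambda *a, **kw: asyncio.run(attr(*a, **kw))
            return attr
    return Sync

SyncGreeter = create_sync_class_toy(AsyncGreeter)
print(SyncGreeter().greet("world"))  # "hello world", no await needed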
@@ -1,4 +1,4 @@
-from typing import TypeVar
+from typing import Type, TypeVar

 class SingletonMetaclass(type):
     T = TypeVar("T", bound="SingletonMetaclass")
@@ -11,13 +11,13 @@ class SingletonMetaclass(type):
             )
         return cls._instances[cls]

-    def inject_instance(cls: type[T], instance: T) -> None:
+    def inject_instance(cls: Type[T], instance: T) -> None:
         assert cls not in SingletonMetaclass._instances, (
             "Cannot inject instance after first instantiation"
         )
         SingletonMetaclass._instances[cls] = instance

-    def get_instance(cls: type[T], *args, **kwargs) -> T:
+    def get_instance(cls: Type[T], *args, **kwargs) -> T:
         """
         Gets the singleton instance of the class, creating it if it doesn't exist.
         """
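A singleton metaclass of this kind caches the first constructed instance and returns it for every later call. A simplified re-implementation (not the project's class) showing the core behaviour:

class SingletonMetaclassSketch(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # First construction caches the instance; later calls return the cached one.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

class Config(metaclass=SingletonMetaclassSketch):
    def __init__(self):
        self.values = {}

assert Config() is Config()  # both constructor calls yield the same object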
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import Type, TYPE_CHECKING
 from comfy_api.internal import ComfyAPIBase
 from comfy_api.internal.singleton import ProxiedSingleton
 from comfy_api.internal.async_to_sync import create_sync_class
@@ -113,7 +113,7 @@ ComfyAPI = ComfyAPI_latest
 if TYPE_CHECKING:
     import comfy_api.latest.generated.ComfyAPISyncStub  # type: ignore

-    ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
+    ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
 ComfyAPISync = create_sync_class(ComfyAPI_latest)

 # create new aliases for io and ui
@@ -1,5 +1,5 @@
 import torch
-from typing import TypedDict, Optional
+from typing import TypedDict, List, Optional

 ImageInput = torch.Tensor
 """
@@ -39,4 +39,4 @@ class LatentInput(TypedDict):
     Optional noise mask tensor in the same format as samples.
     """

-    batch_index: Optional[list[int]]
+    batch_index: Optional[List[int]]
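For context, a minimal sketch of a latent-input TypedDict like the one above being populated; the exact field set is approximated from the keys and docstrings visible in the hunk, so treat it as illustrative:

import torch
from typing import List, Optional, TypedDict

class LatentInputSketch(TypedDict, total=False):
    samples: torch.Tensor
    noise_mask: Optional[torch.Tensor]
    batch_index: Optional[List[int]]

latent: LatentInputSketch = {
    "samples": torch.zeros(2, 4, 64, 64),  # [batch, channels, height, width]
    "noise_mask": None,
    "batch_index": [0, 1],
}
print(latent["samples"].shape)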
@ -774,13 +774,6 @@ class AudioEncoder(ComfyTypeIO):
|
|||||||
class AudioEncoderOutput(ComfyTypeIO):
|
class AudioEncoderOutput(ComfyTypeIO):
|
||||||
Type = Any
|
Type = Any
|
||||||
|
|
||||||
@comfytype(io_type="TRACKS")
|
|
||||||
class Tracks(ComfyTypeIO):
|
|
||||||
class TrackDict(TypedDict):
|
|
||||||
track_path: torch.Tensor
|
|
||||||
track_visibility: torch.Tensor
|
|
||||||
Type = TrackDict
|
|
||||||
|
|
||||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||||
class MultiType:
|
class MultiType:
|
||||||
Type = Any
|
Type = Any
|
||||||
@ -1556,12 +1549,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
|
|
||||||
@final
|
@final
|
||||||
@classmethod
|
@classmethod
|
||||||
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]:
|
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
|
||||||
"""Creates clone of real node class to prevent monkey-patching."""
|
"""Creates clone of real node class to prevent monkey-patching."""
|
||||||
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
||||||
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
||||||
# set hidden
|
# set hidden
|
||||||
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None)
|
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
|
||||||
return type_clone
|
return type_clone
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@ -1822,7 +1815,7 @@ class NodeOutput(_NodeOutputInternal):
|
|||||||
ui = data["ui"]
|
ui = data["ui"]
|
||||||
if "expand" in data:
|
if "expand" in data:
|
||||||
expand = data["expand"]
|
expand = data["expand"]
|
||||||
return cls(*args, ui=ui, expand=expand)
|
return cls(args=args, ui=ui, expand=expand)
|
||||||
|
|
||||||
def __getitem__(self, index) -> Any:
|
def __getitem__(self, index) -> Any:
|
||||||
return self.args[index]
|
return self.args[index]
|
||||||
@ -1901,7 +1894,6 @@ __all__ = [
|
|||||||
"SEGS",
|
"SEGS",
|
||||||
"AnyType",
|
"AnyType",
|
||||||
"MultiType",
|
"MultiType",
|
||||||
"Tracks",
|
|
||||||
# Dynamic Types
|
# Dynamic Types
|
||||||
"MatchType",
|
"MatchType",
|
||||||
# "DynamicCombo",
|
# "DynamicCombo",
|
||||||
|
|||||||
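The NodeOutput hunk above switches between positional unpacking (cls(*args, ...)) and a keyword argument (cls(args=args, ...)), which implies two different constructor shapes. A toy stand-in illustrating the positional-unpacking side; it is not the project's NodeOutput class, only a sketch of the calling convention:

class NodeOutputSketch:
    def __init__(self, *args, ui=None, expand=None):
        self.args = args
        self.ui = ui
        self.expand = expand

    def __getitem__(self, index):
        return self.args[index]

args = ("image_tensor", "mask_tensor")
out = NodeOutputSketch(*args, ui=None, expand=None)  # positional-unpacking style
# The other side of the diff would call NodeOutputSketch(args=args, ...) against a
# constructor that takes args as a keyword parameter instead.
print(out[0])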
@@ -5,6 +5,7 @@ import os
 import random
 import uuid
 from io import BytesIO
+from typing import Type

 import av
 import numpy as np
@@ -82,7 +83,7 @@ class ImageSaveHelper:
         return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))

     @staticmethod
-    def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
+    def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
         """Creates a PngInfo object with prompt and extra_pnginfo."""
         if args.disable_metadata or cls is None or not cls.hidden:
             return None
@@ -95,7 +96,7 @@ class ImageSaveHelper:
         return metadata

     @staticmethod
-    def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
+    def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
         """Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
         if args.disable_metadata or cls is None or not cls.hidden:
             return None
@@ -120,7 +121,7 @@ class ImageSaveHelper:
         return metadata

     @staticmethod
-    def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
+    def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
         """Creates EXIF metadata bytes for WebP images."""
         exif_data = pil_image.getexif()
         if args.disable_metadata or cls is None or cls.hidden is None:
@@ -136,7 +137,7 @@ class ImageSaveHelper:

     @staticmethod
     def save_images(
-        images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
+        images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
     ) -> list[SavedResult]:
         """Saves a batch of images as individual PNG files."""
         full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@@ -154,7 +155,7 @@ class ImageSaveHelper:
         return results

     @staticmethod
-    def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
+    def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
         """Saves a batch of images and returns a UI object for the node output."""
         return SavedImages(
             ImageSaveHelper.save_images(
@@ -168,7 +169,7 @@ class ImageSaveHelper:

     @staticmethod
     def save_animated_png(
-        images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
+        images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
     ) -> SavedResult:
         """Saves a batch of images as a single animated PNG."""
         full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@@ -190,7 +191,7 @@ class ImageSaveHelper:

     @staticmethod
     def get_save_animated_png_ui(
-        images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
+        images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
     ) -> SavedImages:
         """Saves an animated PNG and returns a UI object for the node output."""
         result = ImageSaveHelper.save_animated_png(
@@ -208,7 +209,7 @@ class ImageSaveHelper:
         images,
         filename_prefix: str,
         folder_type: FolderType,
-        cls: type[ComfyNode] | None,
+        cls: Type[ComfyNode] | None,
         fps: float,
         lossless: bool,
         quality: int,
@@ -237,7 +238,7 @@ class ImageSaveHelper:
     def get_save_animated_webp_ui(
         images,
         filename_prefix: str,
-        cls: type[ComfyNode] | None,
+        cls: Type[ComfyNode] | None,
         fps: float,
         lossless: bool,
         quality: int,
@@ -266,7 +267,7 @@ class AudioSaveHelper:
         audio: dict,
         filename_prefix: str,
         folder_type: FolderType,
-        cls: type[ComfyNode] | None,
+        cls: Type[ComfyNode] | None,
         format: str = "flac",
         quality: str = "128k",
     ) -> list[SavedResult]:
@@ -371,7 +372,7 @@ class AudioSaveHelper:

     @staticmethod
     def get_save_audio_ui(
-        audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
+        audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
     ) -> SavedAudios:
         """Save and instantly wrap for UI."""
         return SavedAudios(
@@ -387,7 +388,7 @@ class AudioSaveHelper:


 class PreviewImage(_UIOutput):
-    def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
+    def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
         self.values = ImageSaveHelper.save_images(
             image,
             filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
@@ -411,7 +412,7 @@ class PreviewMask(PreviewImage):


 class PreviewAudio(_UIOutput):
-    def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
+    def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
         self.values = AudioSaveHelper.save_audio(
             audio,
             filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
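The tensor-to-PIL conversion that appears as a context line in the first hunk above is the core of these save helpers. A standalone version of that same conversion, usable outside ComfyUI (the input layout [H, W, C] with float values in 0..1 is taken from the expression shown):

import numpy as np
import torch
from PIL import Image as PILImage

def tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
    # Same conversion as the helper: clamp to 0..255 and cast to 8-bit before wrapping in PIL.
    return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))

pil = tensor_to_pil(torch.rand(64, 64, 3))
pil.save("preview.png", compress_level=4)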
@@ -2,8 +2,9 @@ from comfy_api.latest import ComfyAPI_latest
 from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
 from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
 from comfy_api.internal import ComfyAPIBase
+from typing import List, Type

-supported_versions: list[type[ComfyAPIBase]] = [
+supported_versions: List[Type[ComfyAPIBase]] = [
     ComfyAPI_latest,
     ComfyAPIAdapter_v0_0_2,
     ComfyAPIAdapter_v0_0_1,
@@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel):
     url: str = Field(..., description="URL for generated image")


-class TaskStatusResults(BaseModel):
+class OmniTaskStatusResults(BaseModel):
     videos: list[TaskStatusVideoResult] | None = Field(None)
     images: list[TaskStatusImageResult] | None = Field(None)


-class TaskStatusResponseData(BaseModel):
+class OmniTaskStatusResponseData(BaseModel):
     created_at: int | None = Field(None, description="Task creation time")
     updated_at: int | None = Field(None, description="Task update time")
     task_status: str | None = None
     task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
     task_id: str | None = Field(None, description="Task ID")
-    task_result: TaskStatusResults | None = Field(None)
+    task_result: OmniTaskStatusResults | None = Field(None)


-class TaskStatusResponse(BaseModel):
+class OmniTaskStatusResponse(BaseModel):
     code: int | None = Field(None, description="Error code")
     message: str | None = Field(None, description="Error message")
     request_id: str | None = Field(None, description="Request ID")
-    data: TaskStatusResponseData | None = Field(None)
+    data: OmniTaskStatusResponseData | None = Field(None)


 class OmniImageParamImage(BaseModel):
@@ -84,21 +84,3 @@ class OmniProImageRequest(BaseModel):
     mode: str = Field("pro")
     n: int | None = Field(1, le=9)
     image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
-
-
-class TextToVideoWithAudioRequest(BaseModel):
-    model_name: str = Field(..., description="kling-v2-6")
-    aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
-    duration: str = Field(..., description="'5' or '10'")
-    prompt: str = Field(...)
-    mode: str = Field("pro")
-    sound: str = Field(..., description="'on' or 'off'")
-
-
-class ImageToVideoWithAudioRequest(BaseModel):
-    model_name: str = Field(..., description="kling-v2-6")
-    image: str = Field(...)
-    duration: str = Field(..., description="'5' or '10'")
-    prompt: str = Field(...)
-    mode: str = Field("pro")
-    sound: str = Field(..., description="'on' or 'off'")
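For context, the task-status models above are plain Pydantic envelopes that get populated from the Kling API's JSON. A trimmed-down mirror showing how such an envelope parses a response; the field names are taken from the models above, the payload is invented for illustration, and this assumes Pydantic v2:

from typing import List, Optional
from pydantic import BaseModel, Field

class TaskResultSketch(BaseModel):
    videos: Optional[List[dict]] = Field(None)

class TaskDataSketch(BaseModel):
    task_id: Optional[str] = Field(None)
    task_status: Optional[str] = None
    task_result: Optional[TaskResultSketch] = Field(None)

class TaskStatusEnvelopeSketch(BaseModel):
    code: Optional[int] = Field(None)
    message: Optional[str] = Field(None)
    data: Optional[TaskDataSketch] = Field(None)

payload = {"code": 0, "message": "ok", "data": {"task_id": "abc123", "task_status": "submitted"}}
resp = TaskStatusEnvelopeSketch.model_validate(payload)
print(resp.data.task_id, resp.data.task_status)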
@@ -1,52 +0,0 @@
-from pydantic import BaseModel, Field
-
-
-class Datum2(BaseModel):
-    b64_json: str | None = Field(None, description="Base64 encoded image data")
-    revised_prompt: str | None = Field(None, description="Revised prompt")
-    url: str | None = Field(None, description="URL of the image")
-
-
-class InputTokensDetails(BaseModel):
-    image_tokens: int | None = None
-    text_tokens: int | None = None
-
-
-class Usage(BaseModel):
-    input_tokens: int | None = None
-    input_tokens_details: InputTokensDetails | None = None
-    output_tokens: int | None = None
-    total_tokens: int | None = None
-
-
-class OpenAIImageGenerationResponse(BaseModel):
-    data: list[Datum2] | None = None
-    usage: Usage | None = None
-
-
-class OpenAIImageEditRequest(BaseModel):
-    background: str | None = Field(None, description="Background transparency")
-    model: str = Field(...)
-    moderation: str | None = Field(None)
-    n: int | None = Field(None, description="The number of images to generate")
-    output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
-    output_format: str | None = Field(None)
-    prompt: str = Field(...)
-    quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
-    size: str | None = Field(None, description="Size of the output image")
-
-
-class OpenAIImageGenerationRequest(BaseModel):
-    background: str | None = Field(None, description="Background transparency")
-    model: str | None = Field(None)
-    moderation: str | None = Field(None)
-    n: int | None = Field(
-        None,
-        description="The number of images to generate.",
-    )
-    output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
-    output_format: str | None = Field(None)
-    prompt: str = Field(...)
-    quality: str | None = Field(None, description="The quality of the generated image")
-    size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
-    style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
comfy_api_nodes/apis/pika_api.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+from typing import Optional
+from enum import Enum
+from pydantic import BaseModel, Field
+
+
+class Pikaffect(str, Enum):
+    Cake_ify = "Cake-ify"
+    Crumble = "Crumble"
+    Crush = "Crush"
+    Decapitate = "Decapitate"
+    Deflate = "Deflate"
+    Dissolve = "Dissolve"
+    Explode = "Explode"
+    Eye_pop = "Eye-pop"
+    Inflate = "Inflate"
+    Levitate = "Levitate"
+    Melt = "Melt"
+    Peel = "Peel"
+    Poke = "Poke"
+    Squish = "Squish"
+    Ta_da = "Ta-da"
+    Tear = "Tear"
+
+
+class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
+    aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
+    duration: Optional[int] = Field(5)
+    ingredientsMode: str = Field(...)
+    negativePrompt: Optional[str] = Field(None)
+    promptText: Optional[str] = Field(None)
+    resolution: Optional[str] = Field('1080p')
+    seed: Optional[int] = Field(None)
+
+
+class PikaGenerateResponse(BaseModel):
+    video_id: str = Field(...)
+
+
+class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
+    duration: Optional[int] = 5
+    negativePrompt: Optional[str] = Field(None)
+    promptText: Optional[str] = Field(None)
+    resolution: Optional[str] = '1080p'
+    seed: Optional[int] = Field(None)
+
+
+class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
+    duration: Optional[int] = Field(None, ge=5, le=10)
+    negativePrompt: Optional[str] = Field(None)
+    promptText: str = Field(...)
+    resolution: Optional[str] = '1080p'
+    seed: Optional[int] = Field(None)
+
+
+class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
+    aspectRatio: Optional[float] = Field(
+        1.7777777777777777,
+        description='Aspect ratio (width / height)',
+        ge=0.4,
+        le=2.5,
+    )
+    duration: Optional[int] = 5
+    negativePrompt: Optional[str] = Field(None)
+    promptText: str = Field(...)
+    resolution: Optional[str] = '1080p'
+    seed: Optional[int] = Field(None)
+
+
+class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
+    negativePrompt: Optional[str] = Field(None)
+    promptText: Optional[str] = Field(None)
+    seed: Optional[int] = Field(None)
+
+
+class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
+    negativePrompt: Optional[str] = Field(None)
+    pikaffect: Optional[str] = None
+    promptText: Optional[str] = Field(None)
+    seed: Optional[int] = Field(None)
+
+
+class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
+    negativePrompt: Optional[str] = Field(None)
+    promptText: Optional[str] = Field(None)
+    seed: Optional[int] = Field(None)
+    modifyRegionRoi: Optional[str] = Field(None)
+
+
+class PikaStatusEnum(str, Enum):
+    queued = "queued"
+    started = "started"
+    finished = "finished"
+    failed = "failed"
+
+
+class PikaVideoResponse(BaseModel):
+    id: str = Field(...)
+    progress: Optional[int] = Field(None)
+    status: PikaStatusEnum
+    url: Optional[str] = Field(None)
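As a usage note, these Pika request bodies serialize directly to the JSON payloads sent to the endpoints. A condensed copy of the text-to-video model above, kept only to show construction and serialization (the prompt value is an example; assumes Pydantic v2):

from typing import Optional
from pydantic import BaseModel, Field

class PikaT2vBodySketch(BaseModel):
    aspectRatio: Optional[float] = Field(1.7777777777777777, ge=0.4, le=2.5)
    duration: Optional[int] = 5
    promptText: str = Field(...)
    resolution: Optional[str] = '1080p'
    seed: Optional[int] = None

body = PikaT2vBodySketch(promptText="a red fox at dawn", duration=10)
print(body.model_dump(exclude_none=True))  # ready to send as the request JSON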
@@ -5,17 +5,11 @@ from typing import Optional, List, Dict, Any, Union
 from pydantic import BaseModel, Field, RootModel

 class TripoModelVersion(str, Enum):
-    v3_0_20250812 = 'v3.0-20250812'
     v2_5_20250123 = 'v2.5-20250123'
     v2_0_20240919 = 'v2.0-20240919'
     v1_4_20240625 = 'v1.4-20240625'


-class TripoGeometryQuality(str, Enum):
-    standard = 'standard'
-    detailed = 'detailed'
-
-
 class TripoTextureQuality(str, Enum):
     standard = 'standard'
     detailed = 'detailed'
@@ -67,20 +61,14 @@ class TripoSpec(str, Enum):
 class TripoAnimation(str, Enum):
     IDLE = "preset:idle"
     WALK = "preset:walk"
-    RUN = "preset:run"
-    DIVE = "preset:dive"
     CLIMB = "preset:climb"
     JUMP = "preset:jump"
+    RUN = "preset:run"
     SLASH = "preset:slash"
     SHOOT = "preset:shoot"
     HURT = "preset:hurt"
     FALL = "preset:fall"
     TURN = "preset:turn"
-    QUADRUPED_WALK = "preset:quadruped:walk"
-    HEXAPOD_WALK = "preset:hexapod:walk"
-    OCTOPOD_WALK = "preset:octopod:walk"
-    SERPENTINE_MARCH = "preset:serpentine:march"
-    AQUATIC_MARCH = "preset:aquatic:march"

 class TripoStylizeStyle(str, Enum):
     LEGO = "lego"
@@ -117,11 +105,6 @@ class TripoTaskStatus(str, Enum):
     BANNED = "banned"
     EXPIRED = "expired"

-class TripoFbxPreset(str, Enum):
-    BLENDER = "blender"
-    MIXAMO = "mixamo"
-    _3DSMAX = "3dsmax"
-
 class TripoFileTokenReference(BaseModel):
     type: Optional[str] = Field(None, description='The type of the reference')
     file_token: str
@@ -159,7 +142,6 @@ class TripoTextToModelRequest(BaseModel):
     model_seed: Optional[int] = Field(None, description='The seed for the model')
     texture_seed: Optional[int] = Field(None, description='The seed for the texture')
     texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
-    geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
     style: Optional[TripoStyle] = None
     auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
     quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
@@ -174,7 +156,6 @@ class TripoImageToModelRequest(BaseModel):
     model_seed: Optional[int] = Field(None, description='The seed for the model')
     texture_seed: Optional[int] = Field(None, description='The seed for the texture')
     texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
-    geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
     texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
     style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
     auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
@@ -192,7 +173,6 @@ class TripoMultiviewToModelRequest(BaseModel):
     model_seed: Optional[int] = Field(None, description='The seed for the model')
     texture_seed: Optional[int] = Field(None, description='The seed for the texture')
     texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
-    geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
     texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
     auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
     orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
@@ -239,24 +219,14 @@ class TripoConvertModelRequest(BaseModel):
     type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
     format: TripoConvertFormat = Field(..., description='The format to convert to')
     original_model_task_id: str = Field(..., description='The task ID of the original model')
-    quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
-    force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
-    face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
-    flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
-    flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
-    texture_size: Optional[int] = Field(None, description='The size of the texture')
+    quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
+    force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
+    face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
+    flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
+    flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
+    texture_size: Optional[int] = Field(4096, description='The size of the texture')
     texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
-    pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
-    scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
-    with_animation: Optional[bool] = Field(None, description='Whether to include animations')
-    pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
-    bake: Optional[bool] = Field(None, description='Whether to bake the model')
-    part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
-    fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
-    export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
-    export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
-    animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
-
+    pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')

 class TripoTaskRequest(RootModel):
     root: Union[
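The convert-model hunk above switches the optional fields between Field(None) and concrete defaults (False, 10000, 4096, 0.01). A reduced mirror of that model, written only to show how the two default styles serialize differently; the enum and field subset are illustrative, not the full Tripo schema (assumes Pydantic v2):

from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field

class TextureFormatSketch(str, Enum):
    JPEG = "JPEG"

class ConvertRequestSketch(BaseModel):
    quad: Optional[bool] = Field(False)
    face_limit: Optional[int] = Field(10000)
    texture_size: Optional[int] = Field(4096)
    texture_format: Optional[TextureFormatSketch] = Field(TextureFormatSketch.JPEG)

print(ConvertRequestSketch().model_dump())
# With Field(None) defaults (the other side of the diff), unset values would serialize as null
# and the API, not the client, would decide the effective defaults.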
@@ -50,7 +50,6 @@ from comfy_api_nodes.apis import (
     KlingSingleImageEffectModelName,
 )
 from comfy_api_nodes.apis.kling_api import (
-    ImageToVideoWithAudioRequest,
     OmniImageParamImage,
     OmniParamImage,
     OmniParamVideo,
@@ -58,8 +57,7 @@ from comfy_api_nodes.apis.kling_api import (
     OmniProImageRequest,
     OmniProReferences2VideoRequest,
     OmniProText2VideoRequest,
-    TaskStatusResponse,
-    TextToVideoWithAudioRequest,
+    OmniTaskStatusResponse,
 )
 from comfy_api_nodes.util import (
     ApiEndpoint,
@@ -105,6 +103,10 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320


 MODE_TEXT2VIDEO = {
+    "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
+    "standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
+    "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
+    "pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
     "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
     "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
     "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
@@ -125,6 +127,8 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document


 MODE_START_END_FRAME = {
+    "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
+    "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
     "pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
     "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
     "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
@@ -238,7 +242,7 @@ def normalize_omni_prompt_references(prompt: str) -> str:
     return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)


-async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
     if response.code:
         raise RuntimeError(
             f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@@ -246,7 +250,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
     final_response = await poll_op(
         cls,
         ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
-        response_model=TaskStatusResponse,
+        response_model=OmniTaskStatusResponse,
         status_extractor=lambda r: (r.data.task_status if r.data else None),
         max_poll_attempts=160,
     )
@@ -479,12 +483,12 @@ async def execute_image2video(
     task_id = task_creation_response.data.task_id

     final_response = await poll_op(
         cls,
         ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
         response_model=KlingImage2VideoResponse,
         estimated_duration=AVERAGE_DURATION_I2V,
         status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
     )
     validate_video_result_response(final_response)

     video = get_video_from_response(final_response)
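The MODE_* tables above map a human-readable combo label to a (mode, duration, model_name) triple that the node unpacks before building its request. A tiny lookup example using two entries copied from the table shown:

MODE_TEXT2VIDEO_EXAMPLE = {
    "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
    "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
}

selection = "pro mode / 5s duration / kling-v2-master"
mode, duration, model_name = MODE_TEXT2VIDEO_EXAMPLE[selection]
print(mode, duration, model_name)  # pro 5 kling-v2-master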
@@ -748,7 +752,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
                 IO.Combo.Input(
                     "mode",
                     options=modes,
-                    default=modes[8],
+                    default=modes[4],
                     tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
                 ),
             ],
@@ -830,7 +834,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProText2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -925,7 +929,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProFirstLastFrameRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -993,7 +997,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -1077,7 +1081,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -1158,7 +1162,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=TaskStatusResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -1233,7 +1237,7 @@ class OmniProImageNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
-            response_model=TaskStatusResponse,
+            response_model=OmniTaskStatusResponse,
             data=OmniProImageRequest(
                 model_name=model_name,
                 prompt=prompt,
@@ -1249,7 +1253,7 @@ class OmniProImageNode(IO.ComfyNode):
         final_response = await poll_op(
             cls,
             ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
-            response_model=TaskStatusResponse,
+            response_model=OmniTaskStatusResponse,
             status_extractor=lambda r: (r.data.task_status if r.data else None),
         )
         return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
@@ -1324,8 +1328,9 @@ class KlingImage2VideoNode(IO.ComfyNode):
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="KlingImage2VideoNode",
-            display_name="Kling Image(First Frame) to Video",
+            display_name="Kling Image to Video",
             category="api node/video/Kling",
+            description="Kling Image to Video Node",
             inputs=[
                 IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
                 IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@@ -1483,7 +1488,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
                 IO.Combo.Input(
                     "mode",
                     options=modes,
-                    default=modes[6],
+                    default=modes[8],
                     tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
                 ),
             ],
@@ -1946,7 +1951,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
                 IO.Combo.Input(
                     "model_name",
                     options=[i.value for i in KlingImageGenModelName],
-                    default="kling-v2",
+                    default="kling-v1",
                 ),
                 IO.Combo.Input(
                     "aspect_ratio",
@@ -2029,136 +2034,6 @@ class KlingImageGenerationNode(IO.ComfyNode):
         return IO.NodeOutput(await image_result_to_node_output(images))


-class TextToVideoWithAudio(IO.ComfyNode):
-
-    @classmethod
-    def define_schema(cls) -> IO.Schema:
-        return IO.Schema(
-            node_id="KlingTextToVideoWithAudio",
-            display_name="Kling Text to Video with Audio",
-            category="api node/video/Kling",
-            inputs=[
-                IO.Combo.Input("model_name", options=["kling-v2-6"]),
-                IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
-                IO.Combo.Input("mode", options=["pro"]),
-                IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
-                IO.Combo.Input("duration", options=[5, 10]),
-                IO.Boolean.Input("generate_audio", default=True),
-            ],
-            outputs=[
-                IO.Video.Output(),
-            ],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
-            is_api_node=True,
-        )
-
-    @classmethod
-    async def execute(
-        cls,
-        model_name: str,
-        prompt: str,
-        mode: str,
-        aspect_ratio: str,
-        duration: int,
-        generate_audio: bool,
-    ) -> IO.NodeOutput:
-        validate_string(prompt, min_length=1, max_length=2500)
-        response = await sync_op(
-            cls,
-            ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
-            response_model=TaskStatusResponse,
-            data=TextToVideoWithAudioRequest(
-                model_name=model_name,
-                prompt=prompt,
-                mode=mode,
-                aspect_ratio=aspect_ratio,
-                duration=str(duration),
-                sound="on" if generate_audio else "off",
-            ),
-        )
-        if response.code:
-            raise RuntimeError(
-                f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
-            )
-        final_response = await poll_op(
-            cls,
-            ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
-            response_model=TaskStatusResponse,
-            status_extractor=lambda r: (r.data.task_status if r.data else None),
-        )
-        return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
-
-
-class ImageToVideoWithAudio(IO.ComfyNode):
-
-    @classmethod
-    def define_schema(cls) -> IO.Schema:
-        return IO.Schema(
-            node_id="KlingImageToVideoWithAudio",
-            display_name="Kling Image(First Frame) to Video with Audio",
-            category="api node/video/Kling",
-            inputs=[
-                IO.Combo.Input("model_name", options=["kling-v2-6"]),
-                IO.Image.Input("start_frame"),
-                IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
-                IO.Combo.Input("mode", options=["pro"]),
-                IO.Combo.Input("duration", options=[5, 10]),
-                IO.Boolean.Input("generate_audio", default=True),
-            ],
-            outputs=[
-                IO.Video.Output(),
-            ],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
-            is_api_node=True,
-        )
-
-    @classmethod
-    async def execute(
-        cls,
-        model_name: str,
-        start_frame: Input.Image,
-        prompt: str,
-        mode: str,
-        duration: int,
-        generate_audio: bool,
-    ) -> IO.NodeOutput:
-        validate_string(prompt, min_length=1, max_length=2500)
-        validate_image_dimensions(start_frame, min_width=300, min_height=300)
-        validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
-        response = await sync_op(
-            cls,
-            ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
-            response_model=TaskStatusResponse,
-            data=ImageToVideoWithAudioRequest(
-                model_name=model_name,
-                image=(await upload_images_to_comfyapi(cls, start_frame))[0],
-                prompt=prompt,
-                mode=mode,
-                duration=str(duration),
-                sound="on" if generate_audio else "off",
-            ),
-        )
-        if response.code:
-            raise RuntimeError(
-                f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
-            )
-        final_response = await poll_op(
-            cls,
-            ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
-            response_model=TaskStatusResponse,
-            status_extractor=lambda r: (r.data.task_status if r.data else None),
-        )
-        return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
-
-
 class KlingExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -2181,9 +2056,7 @@ class KlingExtension(ComfyExtension):
|
|||||||
OmniProImageToVideoNode,
|
OmniProImageToVideoNode,
|
||||||
OmniProVideoToVideoNode,
|
OmniProVideoToVideoNode,
|
||||||
OmniProEditVideoNode,
|
OmniProEditVideoNode,
|
||||||
OmniProImageNode,
|
# OmniProImageNode, # need support from backend
|
||||||
TextToVideoWithAudio,
|
|
||||||
ImageToVideoWithAudio,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,45 +1,46 @@
|
|||||||
import base64
|
from io import BytesIO
|
||||||
import os
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from io import BytesIO
|
from inspect import cleandoc
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import folder_paths
|
||||||
|
import base64
|
||||||
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
from comfy_api.latest import IO, ComfyExtension, Input
|
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
CreateModelResponseProperties,
|
OpenAIImageGenerationRequest,
|
||||||
Detail,
|
OpenAIImageEditRequest,
|
||||||
InputContent,
|
OpenAIImageGenerationResponse,
|
||||||
InputFileContent,
|
|
||||||
InputImageContent,
|
|
||||||
InputMessage,
|
|
||||||
InputMessageContentList,
|
|
||||||
InputTextContent,
|
|
||||||
Item,
|
|
||||||
OpenAICreateResponse,
|
OpenAICreateResponse,
|
||||||
OpenAIResponse,
|
OpenAIResponse,
|
||||||
|
CreateModelResponseProperties,
|
||||||
|
Item,
|
||||||
OutputContent,
|
OutputContent,
|
||||||
|
InputImageContent,
|
||||||
|
Detail,
|
||||||
|
InputTextContent,
|
||||||
|
InputMessage,
|
||||||
|
InputMessageContentList,
|
||||||
|
InputContent,
|
||||||
|
InputFileContent,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.openai_api import (
|
|
||||||
OpenAIImageEditRequest,
|
|
||||||
OpenAIImageGenerationRequest,
|
|
||||||
OpenAIImageGenerationResponse,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
|
||||||
download_url_to_bytesio,
|
|
||||||
downscale_image_tensor,
|
downscale_image_tensor,
|
||||||
poll_op,
|
download_url_to_bytesio,
|
||||||
sync_op,
|
|
||||||
tensor_to_base64_string,
|
|
||||||
text_filepath_to_data_uri,
|
|
||||||
validate_string,
|
validate_string,
|
||||||
|
tensor_to_base64_string,
|
||||||
|
ApiEndpoint,
|
||||||
|
sync_op,
|
||||||
|
poll_op,
|
||||||
|
text_filepath_to_data_uri,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
|
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
|
||||||
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
|
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
|
||||||
|
|
||||||
@ -97,6 +98,9 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIDalle2(IO.ComfyNode):
|
class OpenAIDalle2(IO.ComfyNode):
|
||||||
|
"""
|
||||||
|
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -104,7 +108,7 @@ class OpenAIDalle2(IO.ComfyNode):
|
|||||||
node_id="OpenAIDalle2",
|
node_id="OpenAIDalle2",
|
||||||
display_name="OpenAI DALL·E 2",
|
display_name="OpenAI DALL·E 2",
|
||||||
category="api node/image/OpenAI",
|
category="api node/image/OpenAI",
|
||||||
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -230,6 +234,9 @@ class OpenAIDalle2(IO.ComfyNode):
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIDalle3(IO.ComfyNode):
|
class OpenAIDalle3(IO.ComfyNode):
|
||||||
|
"""
|
||||||
|
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -237,7 +244,7 @@ class OpenAIDalle3(IO.ComfyNode):
|
|||||||
node_id="OpenAIDalle3",
|
node_id="OpenAIDalle3",
|
||||||
display_name="OpenAI DALL·E 3",
|
display_name="OpenAI DALL·E 3",
|
||||||
category="api node/image/OpenAI",
|
category="api node/image/OpenAI",
|
||||||
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -319,16 +326,10 @@ class OpenAIDalle3(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||||
|
|
||||||
|
|
||||||
def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None:
|
|
||||||
# https://platform.openai.com/docs/pricing
|
|
||||||
return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None:
|
|
||||||
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIGPTImage1(IO.ComfyNode):
|
class OpenAIGPTImage1(IO.ComfyNode):
|
||||||
|
"""
|
||||||
|
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -336,13 +337,13 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||||||
node_id="OpenAIGPTImage1",
|
node_id="OpenAIGPTImage1",
|
||||||
display_name="OpenAI GPT Image 1",
|
display_name="OpenAI GPT Image 1",
|
||||||
category="api node/image/OpenAI",
|
category="api node/image/OpenAI",
|
||||||
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
default="",
|
default="",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
tooltip="Text prompt for GPT Image",
|
tooltip="Text prompt for GPT Image 1",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
@ -364,8 +365,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"background",
|
"background",
|
||||||
default="auto",
|
default="opaque",
|
||||||
options=["auto", "opaque", "transparent"],
|
options=["opaque", "transparent"],
|
||||||
tooltip="Return image with or without background",
|
tooltip="Return image with or without background",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -396,11 +397,6 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||||||
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
|
||||||
"model",
|
|
||||||
options=["gpt-image-1", "gpt-image-1.5"],
|
|
||||||
optional=True,
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Image.Output(),
|
IO.Image.Output(),
|
||||||
@ -416,34 +412,32 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
prompt: str,
|
prompt,
|
||||||
seed: int = 0,
|
seed=0,
|
||||||
quality: str = "low",
|
quality="low",
|
||||||
background: str = "opaque",
|
background="opaque",
|
||||||
image: Input.Image | None = None,
|
image=None,
|
||||||
mask: Input.Image | None = None,
|
mask=None,
|
||||||
n: int = 1,
|
n=1,
|
||||||
size: str = "1024x1024",
|
size="1024x1024",
|
||||||
model: str = "gpt-image-1",
|
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
model = "gpt-image-1"
|
||||||
if mask is not None and image is None:
|
path = "/proxy/openai/images/generations"
|
||||||
raise ValueError("Cannot use a mask without an input image")
|
content_type = "application/json"
|
||||||
|
request_class = OpenAIImageGenerationRequest
|
||||||
if model == "gpt-image-1":
|
files = []
|
||||||
price_extractor = calculate_tokens_price_image_1
|
|
||||||
elif model == "gpt-image-1.5":
|
|
||||||
price_extractor = calculate_tokens_price_image_1_5
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown model: {model}")
|
|
||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
files = []
|
path = "/proxy/openai/images/edits"
|
||||||
|
request_class = OpenAIImageEditRequest
|
||||||
|
content_type = "multipart/form-data"
|
||||||
|
|
||||||
batch_size = image.shape[0]
|
batch_size = image.shape[0]
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
single_image = image[i: i + 1]
|
single_image = image[i : i + 1]
|
||||||
scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()
|
scaled_image = downscale_image_tensor(single_image).squeeze()
|
||||||
|
|
||||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||||
img = Image.fromarray(image_np)
|
img = Image.fromarray(image_np)
|
||||||
@ -456,59 +450,44 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
if image.shape[0] != 1:
|
if image is None:
|
||||||
raise Exception("Cannot use a mask with multiple image")
|
raise Exception("Cannot use a mask without an input image")
|
||||||
if mask.shape[1:] != image.shape[1:-1]:
|
if image.shape[0] != 1:
|
||||||
raise Exception("Mask and Image must be the same size")
|
raise Exception("Cannot use a mask with multiple image")
|
||||||
_, height, width = mask.shape
|
if mask.shape[1:] != image.shape[1:-1]:
|
||||||
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
raise Exception("Mask and Image must be the same size")
|
||||||
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
batch, height, width = mask.shape
|
||||||
|
rgba_mask = torch.zeros(height, width, 4, device="cpu")
|
||||||
|
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
|
||||||
|
|
||||||
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()
|
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
|
||||||
|
|
||||||
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
|
||||||
mask_img = Image.fromarray(mask_np)
|
mask_img = Image.fromarray(mask_np)
|
||||||
mask_img_byte_arr = BytesIO()
|
mask_img_byte_arr = BytesIO()
|
||||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||||
mask_img_byte_arr.seek(0)
|
mask_img_byte_arr.seek(0)
|
||||||
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||||
|
|
||||||
|
# Build the operation
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=path, method="POST"),
|
||||||
|
response_model=OpenAIImageGenerationResponse,
|
||||||
|
data=request_class(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
quality=quality,
|
||||||
|
background=background,
|
||||||
|
n=n,
|
||||||
|
seed=seed,
|
||||||
|
size=size,
|
||||||
|
),
|
||||||
|
files=files if files else None,
|
||||||
|
content_type=content_type,
|
||||||
|
)
|
||||||
|
|
||||||
response = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
|
|
||||||
response_model=OpenAIImageGenerationResponse,
|
|
||||||
data=OpenAIImageEditRequest(
|
|
||||||
model=model,
|
|
||||||
prompt=prompt,
|
|
||||||
quality=quality,
|
|
||||||
background=background,
|
|
||||||
n=n,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
moderation="low",
|
|
||||||
),
|
|
||||||
content_type="multipart/form-data",
|
|
||||||
files=files,
|
|
||||||
price_extractor=price_extractor,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
|
|
||||||
response_model=OpenAIImageGenerationResponse,
|
|
||||||
data=OpenAIImageGenerationRequest(
|
|
||||||
model=model,
|
|
||||||
prompt=prompt,
|
|
||||||
quality=quality,
|
|
||||||
background=background,
|
|
||||||
n=n,
|
|
||||||
seed=seed,
|
|
||||||
size=size,
|
|
||||||
moderation="low",
|
|
||||||
),
|
|
||||||
price_extractor=price_extractor,
|
|
||||||
)
|
|
||||||
return IO.NodeOutput(await validate_and_cast_response(response))
|
return IO.NodeOutput(await validate_and_cast_response(response))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
575
comfy_api_nodes/nodes_pika.py
Normal file
575
comfy_api_nodes/nodes_pika.py
Normal file
@ -0,0 +1,575 @@
|
|||||||
|
"""
|
||||||
|
Pika x ComfyUI API Nodes
|
||||||
|
|
||||||
|
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, IO
|
||||||
|
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||||
|
from comfy_api_nodes.apis import pika_api as pika_defs
|
||||||
|
from comfy_api_nodes.util import (
|
||||||
|
validate_string,
|
||||||
|
download_url_to_video_output,
|
||||||
|
tensor_to_bytesio,
|
||||||
|
ApiEndpoint,
|
||||||
|
sync_op,
|
||||||
|
poll_op,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||||
|
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||||
|
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
|
||||||
|
|
||||||
|
PIKA_API_VERSION = "2.2"
|
||||||
|
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
|
||||||
|
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
|
||||||
|
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
|
||||||
|
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
||||||
|
|
||||||
|
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_task(
|
||||||
|
task_id: str,
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
||||||
|
response_model=pika_defs.PikaVideoResponse,
|
||||||
|
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||||
|
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||||
|
estimated_duration=60,
|
||||||
|
max_poll_attempts=240,
|
||||||
|
)
|
||||||
|
if not final_response.url:
|
||||||
|
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||||
|
logging.error(error_msg)
|
||||||
|
raise Exception(error_msg)
|
||||||
|
video_url = final_response.url
|
||||||
|
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(video_url))
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_inputs_types() -> list[IO.Input]:
|
||||||
|
"""Get the base required inputs types common to all Pika nodes."""
|
||||||
|
return [
|
||||||
|
IO.String.Input("prompt_text", multiline=True),
|
||||||
|
IO.String.Input("negative_prompt", multiline=True),
|
||||||
|
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||||
|
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
||||||
|
IO.Combo.Input("duration", options=[5, 10], default=5),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PikaImageToVideo(IO.ComfyNode):
|
||||||
|
"""Pika 2.2 Image to Video Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> IO.Schema:
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="PikaImageToVideoNode2_2",
|
||||||
|
display_name="Pika Image to Video",
|
||||||
|
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
|
||||||
|
category="api node/video/Pika",
|
||||||
|
inputs=[
|
||||||
|
IO.Image.Input("image", tooltip="The image to convert to video"),
|
||||||
|
*get_base_inputs_types(),
|
||||||
|
],
|
||||||
|
outputs=[IO.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
image: torch.Tensor,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
resolution: str,
|
||||||
|
duration: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
|
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||||
|
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
resolution=resolution,
|
||||||
|
duration=duration,
|
||||||
|
)
|
||||||
|
initial_operation = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||||
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
|
data=pika_request_data,
|
||||||
|
files=pika_files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
)
|
||||||
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaTextToVideoNode(IO.ComfyNode):
|
||||||
|
"""Pika Text2Video v2.2 Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> IO.Schema:
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="PikaTextToVideoNode2_2",
|
||||||
|
display_name="Pika Text to Video",
|
||||||
|
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
|
||||||
|
category="api node/video/Pika",
|
||||||
|
inputs=[
|
||||||
|
*get_base_inputs_types(),
|
||||||
|
IO.Float.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
step=0.001,
|
||||||
|
min=0.4,
|
||||||
|
max=2.5,
|
||||||
|
default=1.7777777777777777,
|
||||||
|
tooltip="Aspect ratio (width / height)",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
outputs=[IO.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
resolution: str,
|
||||||
|
duration: int,
|
||||||
|
aspect_ratio: float,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
initial_operation = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||||
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
|
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
resolution=resolution,
|
||||||
|
duration=duration,
|
||||||
|
aspectRatio=aspect_ratio,
|
||||||
|
),
|
||||||
|
content_type="application/x-www-form-urlencoded",
|
||||||
|
)
|
||||||
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaScenes(IO.ComfyNode):
|
||||||
|
"""PikaScenes v2.2 Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> IO.Schema:
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="PikaScenesV2_2",
|
||||||
|
display_name="Pika Scenes (Video Image Composition)",
|
||||||
|
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
|
||||||
|
category="api node/video/Pika",
|
||||||
|
inputs=[
|
||||||
|
*get_base_inputs_types(),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"ingredients_mode",
|
||||||
|
options=["creative", "precise"],
|
||||||
|
default="creative",
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
step=0.001,
|
||||||
|
min=0.4,
|
||||||
|
max=2.5,
|
||||||
|
default=1.7777777777777777,
|
||||||
|
tooltip="Aspect ratio (width / height)",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"image_ingredient_1",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Image that will be used as ingredient to create a video.",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"image_ingredient_2",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Image that will be used as ingredient to create a video.",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"image_ingredient_3",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Image that will be used as ingredient to create a video.",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"image_ingredient_4",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Image that will be used as ingredient to create a video.",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"image_ingredient_5",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Image that will be used as ingredient to create a video.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[IO.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
resolution: str,
|
||||||
|
duration: int,
|
||||||
|
ingredients_mode: str,
|
||||||
|
aspect_ratio: float,
|
||||||
|
image_ingredient_1: Optional[torch.Tensor] = None,
|
||||||
|
image_ingredient_2: Optional[torch.Tensor] = None,
|
||||||
|
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||||
|
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||||
|
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
all_image_bytes_io = []
|
||||||
|
for image in [
|
||||||
|
image_ingredient_1,
|
||||||
|
image_ingredient_2,
|
||||||
|
image_ingredient_3,
|
||||||
|
image_ingredient_4,
|
||||||
|
image_ingredient_5,
|
||||||
|
]:
|
||||||
|
if image is not None:
|
||||||
|
all_image_bytes_io.append(tensor_to_bytesio(image))
|
||||||
|
|
||||||
|
pika_files = [
|
||||||
|
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||||
|
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||||
|
]
|
||||||
|
|
||||||
|
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||||
|
ingredientsMode=ingredients_mode,
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
resolution=resolution,
|
||||||
|
duration=duration,
|
||||||
|
aspectRatio=aspect_ratio,
|
||||||
|
)
|
||||||
|
initial_operation = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
||||||
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
|
data=pika_request_data,
|
||||||
|
files=pika_files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
)
|
||||||
|
|
||||||
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
|
class PikAdditionsNode(IO.ComfyNode):
|
||||||
|
"""Pika Pikadditions Node. Add an image into a video."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> IO.Schema:
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="Pikadditions",
|
||||||
|
display_name="Pikadditions (Video Object Insertion)",
|
||||||
|
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
|
||||||
|
category="api node/video/Pika",
|
||||||
|
inputs=[
|
||||||
|
IO.Video.Input("video", tooltip="The video to add an image to."),
|
||||||
|
IO.Image.Input("image", tooltip="The image to add to the video."),
|
||||||
|
IO.String.Input("prompt_text", multiline=True),
|
||||||
|
IO.String.Input("negative_prompt", multiline=True),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[IO.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
video: VideoInput,
|
||||||
|
image: torch.Tensor,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
video_bytes_io = BytesIO()
|
||||||
|
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
|
pika_files = {
|
||||||
|
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||||
|
"image": ("image.png", image_bytes_io, "image/png"),
|
||||||
|
}
|
||||||
|
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
initial_operation = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
||||||
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
|
data=pika_request_data,
|
||||||
|
files=pika_files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
)
|
||||||
|
|
||||||
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaSwapsNode(IO.ComfyNode):
|
||||||
|
"""Pika Pikaswaps Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> IO.Schema:
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="Pikaswaps",
|
||||||
|
display_name="Pika Swaps (Video Object Replacement)",
|
||||||
|
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
|
||||||
|
category="api node/video/Pika",
|
||||||
|
inputs=[
|
||||||
|
IO.Video.Input("video", tooltip="The video to swap an object in."),
|
||||||
|
IO.Image.Input(
|
||||||
|
"image",
|
||||||
|
tooltip="The image used to replace the masked object in the video.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Mask.Input(
|
||||||
|
"mask",
|
||||||
|
tooltip="Use the mask to define areas in the video to replace.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.String.Input("prompt_text", multiline=True, optional=True),
|
||||||
|
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
||||||
|
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
||||||
|
IO.String.Input(
|
||||||
|
"region_to_modify",
|
||||||
|
multiline=True,
|
||||||
|
optional=True,
|
||||||
|
tooltip="Plaintext description of the object / region to modify.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[IO.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
video: VideoInput,
|
||||||
|
image: Optional[torch.Tensor] = None,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
prompt_text: str = "",
|
||||||
|
negative_prompt: str = "",
|
||||||
|
seed: int = 0,
|
||||||
|
region_to_modify: str = "",
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
video_bytes_io = BytesIO()
|
||||||
|
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
|
video_bytes_io.seek(0)
|
||||||
|
pika_files = {
|
||||||
|
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||||
|
}
|
||||||
|
if mask is not None:
|
||||||
|
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
|
||||||
|
if image is not None:
|
||||||
|
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
|
||||||
|
|
||||||
|
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||||
|
)
|
||||||
|
initial_operation = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
||||||
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
|
data=pika_request_data,
|
||||||
|
files=pika_files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
)
|
||||||
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaffectsNode(IO.ComfyNode):
|
||||||
|
"""Pika Pikaffects Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> IO.Schema:
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="Pikaffects",
|
||||||
|
display_name="Pikaffects (Video Effects)",
|
||||||
|
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
|
||||||
|
category="api node/video/Pika",
|
||||||
|
inputs=[
|
||||||
|
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
|
||||||
|
),
|
||||||
|
IO.String.Input("prompt_text", multiline=True),
|
||||||
|
IO.String.Input("negative_prompt", multiline=True),
|
||||||
|
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||||
|
],
|
||||||
|
outputs=[IO.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
image: torch.Tensor,
|
||||||
|
pikaffect: str,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
initial_operation = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
||||||
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
|
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||||
|
pikaffect=pikaffect,
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
)
|
||||||
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaStartEndFrameNode(IO.ComfyNode):
|
||||||
|
"""PikaFrames v2.2 Node."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> IO.Schema:
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="PikaStartEndFrameNode2_2",
|
||||||
|
display_name="Pika Start and End Frame to Video",
|
||||||
|
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
|
||||||
|
category="api node/video/Pika",
|
||||||
|
inputs=[
|
||||||
|
IO.Image.Input("image_start", tooltip="The first image to combine."),
|
||||||
|
IO.Image.Input("image_end", tooltip="The last image to combine."),
|
||||||
|
*get_base_inputs_types(),
|
||||||
|
],
|
||||||
|
outputs=[IO.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
image_start: torch.Tensor,
|
||||||
|
image_end: torch.Tensor,
|
||||||
|
prompt_text: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
seed: int,
|
||||||
|
resolution: str,
|
||||||
|
duration: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt_text, field_name="prompt_text", min_length=1)
|
||||||
|
pika_files = [
|
||||||
|
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||||
|
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||||
|
]
|
||||||
|
initial_operation = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
||||||
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
|
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||||
|
promptText=prompt_text,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
seed=seed,
|
||||||
|
resolution=resolution,
|
||||||
|
duration=duration,
|
||||||
|
),
|
||||||
|
files=pika_files,
|
||||||
|
content_type="multipart/form-data",
|
||||||
|
)
|
||||||
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaApiNodesExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
PikaImageToVideo,
|
||||||
|
PikaTextToVideoNode,
|
||||||
|
PikaScenes,
|
||||||
|
PikAdditionsNode,
|
||||||
|
PikaSwapsNode,
|
||||||
|
PikaffectsNode,
|
||||||
|
PikaStartEndFrameNode,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> PikaApiNodesExtension:
|
||||||
|
return PikaApiNodesExtension()
|
||||||
@ -102,9 +102,8 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
IO.Int.Input("model_seed", default=42, optional=True),
|
IO.Int.Input("model_seed", default=42, optional=True),
|
||||||
IO.Int.Input("texture_seed", default=42, optional=True),
|
IO.Int.Input("texture_seed", default=42, optional=True),
|
||||||
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
|
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||||
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
|
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
||||||
IO.Boolean.Input("quad", default=False, optional=True),
|
IO.Boolean.Input("quad", default=False, optional=True),
|
||||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.String.Output(display_name="model_file"),
|
IO.String.Output(display_name="model_file"),
|
||||||
@ -132,7 +131,6 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
model_seed: Optional[int] = None,
|
model_seed: Optional[int] = None,
|
||||||
texture_seed: Optional[int] = None,
|
texture_seed: Optional[int] = None,
|
||||||
texture_quality: Optional[str] = None,
|
texture_quality: Optional[str] = None,
|
||||||
geometry_quality: Optional[str] = None,
|
|
||||||
face_limit: Optional[int] = None,
|
face_limit: Optional[int] = None,
|
||||||
quad: Optional[bool] = None,
|
quad: Optional[bool] = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
@ -156,7 +154,6 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit,
|
||||||
geometry_quality=geometry_quality,
|
|
||||||
auto_size=True,
|
auto_size=True,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
),
|
),
|
||||||
@ -197,7 +194,6 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
||||||
IO.Boolean.Input("quad", default=False, optional=True),
|
IO.Boolean.Input("quad", default=False, optional=True),
|
||||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.String.Output(display_name="model_file"),
|
IO.String.Output(display_name="model_file"),
|
||||||
@ -224,7 +220,6 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
orientation=None,
|
orientation=None,
|
||||||
texture_seed: Optional[int] = None,
|
texture_seed: Optional[int] = None,
|
||||||
texture_quality: Optional[str] = None,
|
texture_quality: Optional[str] = None,
|
||||||
geometry_quality: Optional[str] = None,
|
|
||||||
texture_alignment: Optional[str] = None,
|
texture_alignment: Optional[str] = None,
|
||||||
face_limit: Optional[int] = None,
|
face_limit: Optional[int] = None,
|
||||||
quad: Optional[bool] = None,
|
quad: Optional[bool] = None,
|
||||||
@ -251,7 +246,6 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
pbr=pbr,
|
pbr=pbr,
|
||||||
model_seed=model_seed,
|
model_seed=model_seed,
|
||||||
orientation=orientation,
|
orientation=orientation,
|
||||||
geometry_quality=geometry_quality,
|
|
||||||
texture_alignment=texture_alignment,
|
texture_alignment=texture_alignment,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
@ -301,7 +295,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
||||||
IO.Boolean.Input("quad", default=False, optional=True),
|
IO.Boolean.Input("quad", default=False, optional=True),
|
||||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.String.Output(display_name="model_file"),
|
IO.String.Output(display_name="model_file"),
|
||||||
@ -330,7 +323,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
model_seed: Optional[int] = None,
|
model_seed: Optional[int] = None,
|
||||||
texture_seed: Optional[int] = None,
|
texture_seed: Optional[int] = None,
|
||||||
texture_quality: Optional[str] = None,
|
texture_quality: Optional[str] = None,
|
||||||
geometry_quality: Optional[str] = None,
|
|
||||||
texture_alignment: Optional[str] = None,
|
texture_alignment: Optional[str] = None,
|
||||||
face_limit: Optional[int] = None,
|
face_limit: Optional[int] = None,
|
||||||
quad: Optional[bool] = None,
|
quad: Optional[bool] = None,
|
||||||
@ -367,7 +359,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
model_seed=model_seed,
|
model_seed=model_seed,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
geometry_quality=geometry_quality,
|
|
||||||
texture_alignment=texture_alignment,
|
texture_alignment=texture_alignment,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
@ -517,8 +508,6 @@ class TripoRetargetNode(IO.ComfyNode):
|
|||||||
options=[
|
options=[
|
||||||
"preset:idle",
|
"preset:idle",
|
||||||
"preset:walk",
|
"preset:walk",
|
||||||
"preset:run",
|
|
||||||
"preset:dive",
|
|
||||||
"preset:climb",
|
"preset:climb",
|
||||||
"preset:jump",
|
"preset:jump",
|
||||||
"preset:slash",
|
"preset:slash",
|
||||||
@ -526,11 +515,6 @@ class TripoRetargetNode(IO.ComfyNode):
|
|||||||
"preset:hurt",
|
"preset:hurt",
|
||||||
"preset:fall",
|
"preset:fall",
|
||||||
"preset:turn",
|
"preset:turn",
|
||||||
"preset:quadruped:walk",
|
|
||||||
"preset:hexapod:walk",
|
|
||||||
"preset:octopod:walk",
|
|
||||||
"preset:serpentine:march",
|
|
||||||
"preset:aquatic:march"
|
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -579,7 +563,7 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
"face_limit",
|
"face_limit",
|
||||||
default=-1,
|
default=-1,
|
||||||
min=-1,
|
min=-1,
|
||||||
max=2000000,
|
max=500000,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -595,40 +579,6 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
default="JPEG",
|
default="JPEG",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Boolean.Input("force_symmetry", default=False, optional=True),
|
|
||||||
IO.Boolean.Input("flatten_bottom", default=False, optional=True),
|
|
||||||
IO.Float.Input(
|
|
||||||
"flatten_bottom_threshold",
|
|
||||||
default=0.0,
|
|
||||||
min=0.0,
|
|
||||||
max=1.0,
|
|
||||||
optional=True,
|
|
||||||
),
|
|
||||||
IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
|
|
||||||
IO.Float.Input(
|
|
||||||
"scale_factor",
|
|
||||||
default=1.0,
|
|
||||||
min=0.0,
|
|
||||||
optional=True,
|
|
||||||
),
|
|
||||||
IO.Boolean.Input("with_animation", default=False, optional=True),
|
|
||||||
IO.Boolean.Input("pack_uv", default=False, optional=True),
|
|
||||||
IO.Boolean.Input("bake", default=False, optional=True),
|
|
||||||
IO.String.Input("part_names", default="", optional=True), # comma-separated list
|
|
||||||
IO.Combo.Input(
|
|
||||||
"fbx_preset",
|
|
||||||
options=["blender", "mixamo", "3dsmax"],
|
|
||||||
default="blender",
|
|
||||||
optional=True,
|
|
||||||
),
|
|
||||||
IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
|
|
||||||
IO.Combo.Input(
|
|
||||||
"export_orientation",
|
|
||||||
options=["align_image", "default"],
|
|
||||||
default="default",
|
|
||||||
optional=True,
|
|
||||||
),
|
|
||||||
IO.Boolean.Input("animate_in_place", default=False, optional=True),
|
|
||||||
],
|
],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
hidden=[
|
hidden=[
|
||||||
@ -654,31 +604,12 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
original_model_task_id,
|
original_model_task_id,
|
||||||
format: str,
|
format: str,
|
||||||
quad: bool,
|
quad: bool,
|
||||||
force_symmetry: bool,
|
|
||||||
face_limit: int,
|
face_limit: int,
|
||||||
flatten_bottom: bool,
|
|
||||||
flatten_bottom_threshold: float,
|
|
||||||
texture_size: int,
|
texture_size: int,
|
||||||
texture_format: str,
|
texture_format: str,
|
||||||
pivot_to_center_bottom: bool,
|
|
||||||
scale_factor: float,
|
|
||||||
with_animation: bool,
|
|
||||||
pack_uv: bool,
|
|
||||||
bake: bool,
|
|
||||||
part_names: str,
|
|
||||||
fbx_preset: str,
|
|
||||||
export_vertex_colors: bool,
|
|
||||||
export_orientation: str,
|
|
||||||
animate_in_place: bool,
|
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if not original_model_task_id:
|
if not original_model_task_id:
|
||||||
raise RuntimeError("original_model_task_id is required")
|
raise RuntimeError("original_model_task_id is required")
|
||||||
|
|
||||||
# Parse part_names from comma-separated string to list
|
|
||||||
part_names_list = None
|
|
||||||
if part_names and part_names.strip():
|
|
||||||
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
|
|
||||||
|
|
||||||
response = await sync_op(
|
response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
|
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
|
||||||
@ -687,22 +618,9 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
original_model_task_id=original_model_task_id,
|
original_model_task_id=original_model_task_id,
|
||||||
format=format,
|
format=format,
|
||||||
quad=quad if quad else None,
|
quad=quad if quad else None,
|
||||||
force_symmetry=force_symmetry if force_symmetry else None,
|
|
||||||
face_limit=face_limit if face_limit != -1 else None,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
flatten_bottom=flatten_bottom if flatten_bottom else None,
|
|
||||||
flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
|
|
||||||
texture_size=texture_size if texture_size != 4096 else None,
|
texture_size=texture_size if texture_size != 4096 else None,
|
||||||
texture_format=texture_format if texture_format != "JPEG" else None,
|
texture_format=texture_format if texture_format != "JPEG" else None,
|
||||||
pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
|
|
||||||
scale_factor=scale_factor if scale_factor != 1.0 else None,
|
|
||||||
with_animation=with_animation if with_animation else None,
|
|
||||||
pack_uv=pack_uv if pack_uv else None,
|
|
||||||
bake=bake if bake else None,
|
|
||||||
part_names=part_names_list,
|
|
||||||
fbx_preset=fbx_preset if fbx_preset != "blender" else None,
|
|
||||||
export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
|
|
||||||
export_orientation=export_orientation if export_orientation != "default" else None,
|
|
||||||
animate_in_place=animate_in_place if animate_in_place else None,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await poll_until_finished(cls, response, average_duration=30)
|
return await poll_until_finished(cls, response, average_duration=30)
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@ -19,26 +21,26 @@ from comfy_api_nodes.util import (
|
|||||||
|
|
||||||
class Text2ImageInputField(BaseModel):
|
class Text2ImageInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: str | None = Field(None)
|
negative_prompt: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class Image2ImageInputField(BaseModel):
|
class Image2ImageInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: str | None = Field(None)
|
negative_prompt: Optional[str] = Field(None)
|
||||||
images: list[str] = Field(..., min_length=1, max_length=2)
|
images: list[str] = Field(..., min_length=1, max_length=2)
|
||||||
|
|
||||||
|
|
||||||
class Text2VideoInputField(BaseModel):
|
class Text2VideoInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: str | None = Field(None)
|
negative_prompt: Optional[str] = Field(None)
|
||||||
audio_url: str | None = Field(None)
|
audio_url: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class Image2VideoInputField(BaseModel):
|
class Image2VideoInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: str | None = Field(None)
|
negative_prompt: Optional[str] = Field(None)
|
||||||
img_url: str = Field(...)
|
img_url: str = Field(...)
|
||||||
audio_url: str | None = Field(None)
|
audio_url: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class Txt2ImageParametersField(BaseModel):
|
class Txt2ImageParametersField(BaseModel):
|
||||||
@ -50,7 +52,7 @@ class Txt2ImageParametersField(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Image2ImageParametersField(BaseModel):
|
class Image2ImageParametersField(BaseModel):
|
||||||
size: str | None = Field(None)
|
size: Optional[str] = Field(None)
|
||||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(True)
|
||||||
@ -59,21 +61,19 @@ class Image2ImageParametersField(BaseModel):
|
|||||||
class Text2VideoParametersField(BaseModel):
|
class Text2VideoParametersField(BaseModel):
|
||||||
size: str = Field(...)
|
size: str = Field(...)
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
duration: int = Field(5, ge=5, le=15)
|
duration: int = Field(5, ge=5, le=10)
|
||||||
prompt_extend: bool = Field(True)
|
prompt_extend: bool = Field(True)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(True)
|
||||||
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
audio: bool = Field(False, description="Should be audio generated automatically")
|
||||||
shot_type: str = Field("single")
|
|
||||||
|
|
||||||
|
|
||||||
class Image2VideoParametersField(BaseModel):
|
class Image2VideoParametersField(BaseModel):
|
||||||
resolution: str = Field(...)
|
resolution: str = Field(...)
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
duration: int = Field(5, ge=5, le=15)
|
duration: int = Field(5, ge=5, le=10)
|
||||||
prompt_extend: bool = Field(True)
|
prompt_extend: bool = Field(True)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(True)
|
||||||
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
audio: bool = Field(False, description="Should be audio generated automatically")
|
||||||
shot_type: str = Field("single")
|
|
||||||
|
|
||||||
|
|
||||||
class Text2ImageTaskCreationRequest(BaseModel):
|
class Text2ImageTaskCreationRequest(BaseModel):
|
||||||
@ -106,39 +106,39 @@ class TaskCreationOutputField(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class TaskCreationResponse(BaseModel):
|
class TaskCreationResponse(BaseModel):
|
||||||
output: TaskCreationOutputField | None = Field(None)
|
output: Optional[TaskCreationOutputField] = Field(None)
|
||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
code: str | None = Field(None, description="Error code for the failed request.")
|
code: Optional[str] = Field(None, description="The error code of the failed request.")
|
||||||
message: str | None = Field(None, description="Details about the failed request.")
|
message: Optional[str] = Field(None, description="Details of the failed request.")
|
||||||
|
|
||||||
|
|
||||||
class TaskResult(BaseModel):
|
class TaskResult(BaseModel):
|
||||||
url: str | None = Field(None)
|
url: Optional[str] = Field(None)
|
||||||
code: str | None = Field(None)
|
code: Optional[str] = Field(None)
|
||||||
message: str | None = Field(None)
|
message: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class ImageTaskStatusOutputField(TaskCreationOutputField):
|
class ImageTaskStatusOutputField(TaskCreationOutputField):
|
||||||
task_id: str = Field(...)
|
task_id: str = Field(...)
|
||||||
task_status: str = Field(...)
|
task_status: str = Field(...)
|
||||||
results: list[TaskResult] | None = Field(None)
|
results: Optional[list[TaskResult]] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class VideoTaskStatusOutputField(TaskCreationOutputField):
|
class VideoTaskStatusOutputField(TaskCreationOutputField):
|
||||||
task_id: str = Field(...)
|
task_id: str = Field(...)
|
||||||
task_status: str = Field(...)
|
task_status: str = Field(...)
|
||||||
video_url: str | None = Field(None)
|
video_url: Optional[str] = Field(None)
|
||||||
code: str | None = Field(None)
|
code: Optional[str] = Field(None)
|
||||||
message: str | None = Field(None)
|
message: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class ImageTaskStatusResponse(BaseModel):
|
class ImageTaskStatusResponse(BaseModel):
|
||||||
output: ImageTaskStatusOutputField | None = Field(None)
|
output: Optional[ImageTaskStatusOutputField] = Field(None)
|
||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class VideoTaskStatusResponse(BaseModel):
|
class VideoTaskStatusResponse(BaseModel):
|
||||||
output: VideoTaskStatusOutputField | None = Field(None)
|
output: Optional[VideoTaskStatusOutputField] = Field(None)
|
||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
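The hunk above mostly swaps typing.Optional[X] annotations for the PEP 604 X | None form. For Pydantic v2 the two spellings behave the same; the union syntax only requires Python 3.10 or newer. A minimal sketch (the model names are illustrative only, not classes from this file):

from typing import Optional
from pydantic import BaseModel, Field

class OldStyle(BaseModel):
    url: Optional[str] = Field(None)

class NewStyle(BaseModel):
    url: str | None = Field(None)

print(OldStyle(url=None).model_dump() == NewStyle(url=None).model_dump())  # True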
@ -152,7 +152,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
node_id="WanTextToImageApi",
|
node_id="WanTextToImageApi",
|
||||||
display_name="Wan Text to Image",
|
display_name="Wan Text to Image",
|
||||||
category="api node/image/Wan",
|
category="api node/image/Wan",
|
||||||
description="Generates an image based on a text prompt.",
|
description="Generates image based on text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
@ -164,13 +164,13 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative prompt describing what to avoid.",
|
tooltip="Negative text prompt to guide what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -209,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -252,7 +252,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
@ -272,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
display_name="Wan Image to Image",
|
display_name="Wan Image to Image",
|
||||||
category="api node/image/Wan",
|
category="api node/image/Wan",
|
||||||
description="Generates an image from one or two input images and a text prompt. "
|
description="Generates an image from one or two input images and a text prompt. "
|
||||||
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
|
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
@ -282,19 +282,19 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="Single-image editing or multi-image fusion. Maximum 2 images.",
|
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative prompt describing what to avoid.",
|
tooltip="Negative text prompt to guide what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
# redo this later as an optional combo of recommended resolutions
|
# redo this later as an optional combo of recommended resolutions
|
||||||
@ -328,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -347,7 +347,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
image: Input.Image,
|
image: torch.Tensor,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
# width: int = 1024,
|
# width: int = 1024,
|
||||||
@ -357,7 +357,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
):
|
):
|
||||||
n_images = get_number_of_images(image)
|
n_images = get_number_of_images(image)
|
||||||
if n_images not in (1, 2):
|
if n_images not in (1, 2):
|
||||||
raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.")
|
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
|
||||||
images = []
|
images = []
|
||||||
for i in image:
|
for i in image:
|
||||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
||||||
@ -376,7 +376,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
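The image-to-image execute method above sends each input image as a base64 PNG data URI built by tensor_to_base64_string. A rough standalone stand-in for that helper, assuming the usual (H, W, C) float image layout in [0, 1] and omitting the total-pixel budget the real helper enforces:

import base64
from io import BytesIO

import torch
from PIL import Image

def tensor_to_data_uri(t: torch.Tensor) -> str:
    # Convert a float [0, 1] HWC tensor to uint8 and encode it as a PNG data URI.
    arr = (t.clamp(0, 1) * 255).round().to(torch.uint8).cpu().numpy()
    buf = BytesIO()
    Image.fromarray(arr).save(buf, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("ascii")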
@ -395,25 +395,25 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
node_id="WanTextToVideoApi",
|
node_id="WanTextToVideoApi",
|
||||||
display_name="Wan Text to Video",
|
display_name="Wan Text to Video",
|
||||||
category="api node/video/Wan",
|
category="api node/video/Wan",
|
||||||
description="Generates a video based on a text prompt.",
|
description="Generates video based on text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["wan2.5-t2v-preview", "wan2.6-t2v"],
|
options=["wan2.5-t2v-preview"],
|
||||||
default="wan2.6-t2v",
|
default="wan2.5-t2v-preview",
|
||||||
tooltip="Model to use.",
|
tooltip="Model to use.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative prompt describing what to avoid.",
|
tooltip="Negative text prompt to guide what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
@ -433,23 +433,23 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
"1080p: 4:3 (1632x1248)",
|
"1080p: 4:3 (1632x1248)",
|
||||||
"1080p: 3:4 (1248x1632)",
|
"1080p: 3:4 (1248x1632)",
|
||||||
],
|
],
|
||||||
default="720p: 1:1 (960x960)",
|
default="480p: 1:1 (624x624)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=15,
|
max=10,
|
||||||
step=5,
|
step=5,
|
||||||
display_mode=IO.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="A 15-second duration is available only for the Wan 2.6 model.",
|
tooltip="Available durations: 5 and 10 seconds",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Audio.Input(
|
IO.Audio.Input(
|
||||||
"audio",
|
"audio",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
|
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
@ -466,7 +466,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
"generate_audio",
|
"generate_audio",
|
||||||
default=False,
|
default=False,
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="If no audio input is provided, generate audio automatically.",
|
tooltip="If there is no audio input, generate audio automatically.",
|
||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"prompt_extend",
|
"prompt_extend",
|
||||||
@ -477,15 +477,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||||
optional=True,
|
|
||||||
),
|
|
||||||
IO.Combo.Input(
|
|
||||||
"shot_type",
|
|
||||||
options=["single", "multi"],
|
|
||||||
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
|
||||||
"single continuous shot or multiple shots with cuts. "
|
|
||||||
"This parameter takes effect only when prompt_extend is True.",
|
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -506,19 +498,14 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
model: str,
|
model: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
size: str = "720p: 1:1 (960x960)",
|
size: str = "480p: 1:1 (624x624)",
|
||||||
duration: int = 5,
|
duration: int = 5,
|
||||||
audio: Input.Audio | None = None,
|
audio: Optional[Input.Audio] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
generate_audio: bool = False,
|
generate_audio: bool = False,
|
||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = True,
|
||||||
shot_type: str = "single",
|
|
||||||
):
|
):
|
||||||
if "480p" in size and model == "wan2.6-t2v":
|
|
||||||
raise ValueError("The Wan 2.6 model does not support 480p.")
|
|
||||||
if duration == 15 and model == "wan2.5-t2v-preview":
|
|
||||||
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
|
|
||||||
width, height = RES_IN_PARENS.search(size).groups()
|
width, height = RES_IN_PARENS.search(size).groups()
|
||||||
audio_url = None
|
audio_url = None
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
@ -539,12 +526,11 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
audio=generate_audio,
|
audio=generate_audio,
|
||||||
prompt_extend=prompt_extend,
|
prompt_extend=prompt_extend,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
shot_type=shot_type,
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
@ -563,12 +549,12 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
node_id="WanImageToVideoApi",
|
node_id="WanImageToVideoApi",
|
||||||
display_name="Wan Image to Video",
|
display_name="Wan Image to Video",
|
||||||
category="api node/video/Wan",
|
category="api node/video/Wan",
|
||||||
description="Generates a video from the first frame and a text prompt.",
|
description="Generates video based on the first frame and text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["wan2.5-i2v-preview", "wan2.6-i2v"],
|
options=["wan2.5-i2v-preview"],
|
||||||
default="wan2.6-i2v",
|
default="wan2.5-i2v-preview",
|
||||||
tooltip="Model to use.",
|
tooltip="Model to use.",
|
||||||
),
|
),
|
||||||
IO.Image.Input(
|
IO.Image.Input(
|
||||||
@ -578,13 +564,13 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative prompt describing what to avoid.",
|
tooltip="Negative text prompt to guide what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
@ -594,23 +580,23 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
"720P",
|
"720P",
|
||||||
"1080P",
|
"1080P",
|
||||||
],
|
],
|
||||||
default="720P",
|
default="480P",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=15,
|
max=10,
|
||||||
step=5,
|
step=5,
|
||||||
display_mode=IO.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Duration 15 available only for WAN2.6 model.",
|
tooltip="Available durations: 5 and 10 seconds",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Audio.Input(
|
IO.Audio.Input(
|
||||||
"audio",
|
"audio",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
|
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
@ -627,7 +613,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
"generate_audio",
|
"generate_audio",
|
||||||
default=False,
|
default=False,
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="If no audio input is provided, generate audio automatically.",
|
tooltip="If there is no audio input, generate audio automatically.",
|
||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"prompt_extend",
|
"prompt_extend",
|
||||||
@ -638,15 +624,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||||
optional=True,
|
|
||||||
),
|
|
||||||
IO.Combo.Input(
|
|
||||||
"shot_type",
|
|
||||||
options=["single", "multi"],
|
|
||||||
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
|
||||||
"single continuous shot or multiple shots with cuts. "
|
|
||||||
"This parameter takes effect only when prompt_extend is True.",
|
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -665,24 +643,19 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
image: Input.Image,
|
image: torch.Tensor,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
resolution: str = "720P",
|
resolution: str = "480P",
|
||||||
duration: int = 5,
|
duration: int = 5,
|
||||||
audio: Input.Audio | None = None,
|
audio: Optional[Input.Audio] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
generate_audio: bool = False,
|
generate_audio: bool = False,
|
||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = True,
|
||||||
shot_type: str = "single",
|
|
||||||
):
|
):
|
||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
raise ValueError("Exactly one input image is required.")
|
raise ValueError("Exactly one input image is required.")
|
||||||
if "480P" in resolution and model == "wan2.6-i2v":
|
|
||||||
raise ValueError("The Wan 2.6 model does not support 480P.")
|
|
||||||
if duration == 15 and model == "wan2.5-i2v-preview":
|
|
||||||
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
|
|
||||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
||||||
audio_url = None
|
audio_url = None
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
@ -704,12 +677,11 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
audio=generate_audio,
|
audio=generate_audio,
|
||||||
prompt_extend=prompt_extend,
|
prompt_extend=prompt_extend,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
shot_type=shot_type,
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
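All of the Wan nodes above follow the same create-then-poll pattern: the creation call returns output.task_id, and the node then polls /proxy/wan/api/v1/tasks/{task_id} until the task finishes. A minimal sketch of that loop with plain aiohttp; the base URL, creation path, and the SUCCEEDED/FAILED status strings are assumptions here, and the real nodes delegate all of this to the sync_op/poll_op helpers:

import asyncio
import aiohttp

BASE = "https://example.invalid/proxy/wan/api/v1"  # placeholder host and prefix

async def create_and_poll(create_path: str, payload: dict, interval: float = 5.0) -> dict:
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{BASE}/{create_path}", json=payload) as resp:
            created = await resp.json()
        if not created.get("output"):
            raise RuntimeError(f"{created.get('code')} - {created.get('message')}")
        task_id = created["output"]["task_id"]
        while True:
            async with session.get(f"{BASE}/tasks/{task_id}") as resp:
                status = await resp.json()
            state = status["output"]["task_status"]
            if state == "SUCCEEDED":  # assumed terminal status strings
                return status["output"]
            if state in ("FAILED", "CANCELED"):
                raise RuntimeError(f"{status['output'].get('code')}: {status['output'].get('message')}")
            await asyncio.sleep(interval)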
|
|||||||
@ -129,7 +129,7 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
|||||||
return img_byte_arr
|
return img_byte_arr
|
||||||
|
|
||||||
|
|
||||||
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
|
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||||
samples = image.movedim(-1, 1)
|
samples = image.movedim(-1, 1)
|
||||||
total = int(total_pixels)
|
total = int(total_pixels)
|
||||||
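downscale_image_tensor above shrinks an image so its pixel count roughly matches a budget. The same scale-factor arithmetic spelled out as a separate sketch, assuming the usual (B, H, W, C) float layout (this is not the library function itself):

import math

import torch
import torch.nn.functional as F

def downscale_to_budget(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
    b, h, w, c = image.shape
    if h * w <= total_pixels:
        return image                                   # already within budget
    scale = math.sqrt(total_pixels / (h * w))
    new_h, new_w = max(1, round(h * scale)), max(1, round(w * scale))
    samples = image.movedim(-1, 1)                     # (B, C, H, W) for interpolate
    samples = F.interpolate(samples, size=(new_h, new_w), mode="area")
    return samples.movedim(1, -1)                      # back to (B, H, W, C)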
|
|||||||
@ -1,291 +0,0 @@
|
|||||||
"""
|
|
||||||
Job utilities for the /api/jobs endpoint.
|
|
||||||
Provides normalization and helper functions for job status tracking.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from comfy_api.internal import prune_dict
|
|
||||||
|
|
||||||
|
|
||||||
class JobStatus:
|
|
||||||
"""Job status constants."""
|
|
||||||
PENDING = 'pending'
|
|
||||||
IN_PROGRESS = 'in_progress'
|
|
||||||
COMPLETED = 'completed'
|
|
||||||
FAILED = 'failed'
|
|
||||||
|
|
||||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]
|
|
||||||
|
|
||||||
|
|
||||||
# Media types that can be previewed in the frontend
|
|
||||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
|
|
||||||
|
|
||||||
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
|
||||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
|
||||||
"""Extract create_time and workflow_id from extra_data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (create_time, workflow_id)
|
|
||||||
"""
|
|
||||||
create_time = extra_data.get('create_time')
|
|
||||||
extra_pnginfo = extra_data.get('extra_pnginfo', {})
|
|
||||||
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
|
|
||||||
return create_time, workflow_id
|
|
||||||
|
|
||||||
|
|
||||||
def is_previewable(media_type: str, item: dict) -> bool:
|
|
||||||
"""
|
|
||||||
Check if an output item is previewable.
|
|
||||||
Matches frontend logic in ComfyUI_frontend/src/stores/queueStore.ts
|
|
||||||
Maintains backwards compatibility with existing logic.
|
|
||||||
|
|
||||||
Priority:
|
|
||||||
1. media_type is 'images', 'video', or 'audio'
|
|
||||||
2. format field starts with 'video/' or 'audio/'
|
|
||||||
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
|
|
||||||
"""
|
|
||||||
if media_type in PREVIEWABLE_MEDIA_TYPES:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check format field (MIME type).
|
|
||||||
# Maintains backwards compatibility with how custom node outputs are handled in the frontend.
|
|
||||||
fmt = item.get('format', '')
|
|
||||||
if fmt and (fmt.startswith('video/') or fmt.startswith('audio/')):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check for 3D files by extension
|
|
||||||
filename = item.get('filename', '').lower()
|
|
||||||
if any(filename.endswith(ext) for ext in THREE_D_EXTENSIONS):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_queue_item(item: tuple, status: str) -> dict:
|
|
||||||
"""Convert queue item tuple to unified job dict.
|
|
||||||
|
|
||||||
Expects item with sensitive data already removed (5 elements).
|
|
||||||
"""
|
|
||||||
priority, prompt_id, _, extra_data, _ = item
|
|
||||||
create_time, workflow_id = _extract_job_metadata(extra_data)
|
|
||||||
|
|
||||||
return prune_dict({
|
|
||||||
'id': prompt_id,
|
|
||||||
'status': status,
|
|
||||||
'priority': priority,
|
|
||||||
'create_time': create_time,
|
|
||||||
'outputs_count': 0,
|
|
||||||
'workflow_id': workflow_id,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: bool = False) -> dict:
|
|
||||||
"""Convert history item dict to unified job dict.
|
|
||||||
|
|
||||||
History items have sensitive data already removed (prompt tuple has 5 elements).
|
|
||||||
"""
|
|
||||||
prompt_tuple = history_item['prompt']
|
|
||||||
priority, _, prompt, extra_data, _ = prompt_tuple
|
|
||||||
create_time, workflow_id = _extract_job_metadata(extra_data)
|
|
||||||
|
|
||||||
status_info = history_item.get('status', {})
|
|
||||||
status_str = status_info.get('status_str') if status_info else None
|
|
||||||
if status_str == 'success':
|
|
||||||
status = JobStatus.COMPLETED
|
|
||||||
elif status_str == 'error':
|
|
||||||
status = JobStatus.FAILED
|
|
||||||
else:
|
|
||||||
status = JobStatus.COMPLETED
|
|
||||||
|
|
||||||
outputs = history_item.get('outputs', {})
|
|
||||||
outputs_count, preview_output = get_outputs_summary(outputs)
|
|
||||||
|
|
||||||
execution_error = None
|
|
||||||
execution_start_time = None
|
|
||||||
execution_end_time = None
|
|
||||||
if status_info:
|
|
||||||
messages = status_info.get('messages', [])
|
|
||||||
for entry in messages:
|
|
||||||
if isinstance(entry, (list, tuple)) and len(entry) >= 2:
|
|
||||||
event_name, event_data = entry[0], entry[1]
|
|
||||||
if isinstance(event_data, dict):
|
|
||||||
if event_name == 'execution_start':
|
|
||||||
execution_start_time = event_data.get('timestamp')
|
|
||||||
elif event_name in ('execution_success', 'execution_error', 'execution_interrupted'):
|
|
||||||
execution_end_time = event_data.get('timestamp')
|
|
||||||
if event_name == 'execution_error':
|
|
||||||
execution_error = event_data
|
|
||||||
|
|
||||||
job = prune_dict({
|
|
||||||
'id': prompt_id,
|
|
||||||
'status': status,
|
|
||||||
'priority': priority,
|
|
||||||
'create_time': create_time,
|
|
||||||
'execution_start_time': execution_start_time,
|
|
||||||
'execution_end_time': execution_end_time,
|
|
||||||
'execution_error': execution_error,
|
|
||||||
'outputs_count': outputs_count,
|
|
||||||
'preview_output': preview_output,
|
|
||||||
'workflow_id': workflow_id,
|
|
||||||
})
|
|
||||||
|
|
||||||
if include_outputs:
|
|
||||||
job['outputs'] = outputs
|
|
||||||
job['execution_status'] = status_info
|
|
||||||
job['workflow'] = {
|
|
||||||
'prompt': prompt,
|
|
||||||
'extra_data': extra_data,
|
|
||||||
}
|
|
||||||
|
|
||||||
return job
|
|
||||||
|
|
||||||
|
|
||||||
def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
|
|
||||||
"""
|
|
||||||
Count outputs and find preview in a single pass.
|
|
||||||
Returns (outputs_count, preview_output).
|
|
||||||
|
|
||||||
Preview priority (matching frontend):
|
|
||||||
1. type="output" with previewable media
|
|
||||||
2. Any previewable media
|
|
||||||
"""
|
|
||||||
count = 0
|
|
||||||
preview_output = None
|
|
||||||
fallback_preview = None
|
|
||||||
|
|
||||||
for node_id, node_outputs in outputs.items():
|
|
||||||
if not isinstance(node_outputs, dict):
|
|
||||||
continue
|
|
||||||
for media_type, items in node_outputs.items():
|
|
||||||
# 'animated' is a boolean flag, not actual output items
|
|
||||||
if media_type == 'animated' or not isinstance(items, list):
|
|
||||||
continue
|
|
||||||
|
|
||||||
for item in items:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
if preview_output is None and is_previewable(media_type, item):
|
|
||||||
enriched = {
|
|
||||||
**item,
|
|
||||||
'nodeId': node_id,
|
|
||||||
'mediaType': media_type
|
|
||||||
}
|
|
||||||
if item.get('type') == 'output':
|
|
||||||
preview_output = enriched
|
|
||||||
elif fallback_preview is None:
|
|
||||||
fallback_preview = enriched
|
|
||||||
|
|
||||||
return count, preview_output or fallback_preview
|
|
||||||
|
|
||||||
|
|
||||||
def apply_sorting(jobs: list[dict], sort_by: str, sort_order: str) -> list[dict]:
|
|
||||||
"""Sort jobs list by specified field and order."""
|
|
||||||
reverse = (sort_order == 'desc')
|
|
||||||
|
|
||||||
if sort_by == 'execution_duration':
|
|
||||||
def get_sort_key(job):
|
|
||||||
start = job.get('execution_start_time', 0)
|
|
||||||
end = job.get('execution_end_time', 0)
|
|
||||||
return end - start if end and start else 0
|
|
||||||
else:
|
|
||||||
def get_sort_key(job):
|
|
||||||
return job.get('create_time', 0)
|
|
||||||
|
|
||||||
return sorted(jobs, key=get_sort_key, reverse=reverse)
|
|
||||||
|
|
||||||
|
|
||||||
def get_job(prompt_id: str, running: list, queued: list, history: dict) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
Get a single job by prompt_id from history or queue.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt_id: The prompt ID to look up
|
|
||||||
running: List of currently running queue items
|
|
||||||
queued: List of pending queue items
|
|
||||||
history: Dict of history items keyed by prompt_id
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Job dict with full details, or None if not found
|
|
||||||
"""
|
|
||||||
if prompt_id in history:
|
|
||||||
return normalize_history_item(prompt_id, history[prompt_id], include_outputs=True)
|
|
||||||
|
|
||||||
for item in running:
|
|
||||||
if item[1] == prompt_id:
|
|
||||||
return normalize_queue_item(item, JobStatus.IN_PROGRESS)
|
|
||||||
|
|
||||||
for item in queued:
|
|
||||||
if item[1] == prompt_id:
|
|
||||||
return normalize_queue_item(item, JobStatus.PENDING)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_jobs(
|
|
||||||
running: list,
|
|
||||||
queued: list,
|
|
||||||
history: dict,
|
|
||||||
status_filter: Optional[list[str]] = None,
|
|
||||||
workflow_id: Optional[str] = None,
|
|
||||||
sort_by: str = "created_at",
|
|
||||||
sort_order: str = "desc",
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
offset: int = 0
|
|
||||||
) -> tuple[list[dict], int]:
|
|
||||||
"""
|
|
||||||
Get all jobs (running, pending, completed) with filtering and sorting.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
running: List of currently running queue items
|
|
||||||
queued: List of pending queue items
|
|
||||||
history: Dict of history items keyed by prompt_id
|
|
||||||
status_filter: List of statuses to include (from JobStatus.ALL)
|
|
||||||
workflow_id: Filter by workflow ID
|
|
||||||
sort_by: Field to sort by ('created_at', 'execution_duration')
|
|
||||||
sort_order: 'asc' or 'desc'
|
|
||||||
limit: Maximum number of items to return
|
|
||||||
offset: Number of items to skip
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (jobs_list, total_count)
|
|
||||||
"""
|
|
||||||
jobs = []
|
|
||||||
|
|
||||||
if status_filter is None:
|
|
||||||
status_filter = JobStatus.ALL
|
|
||||||
|
|
||||||
if JobStatus.IN_PROGRESS in status_filter:
|
|
||||||
for item in running:
|
|
||||||
jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))
|
|
||||||
|
|
||||||
if JobStatus.PENDING in status_filter:
|
|
||||||
for item in queued:
|
|
||||||
jobs.append(normalize_queue_item(item, JobStatus.PENDING))
|
|
||||||
|
|
||||||
include_completed = JobStatus.COMPLETED in status_filter
|
|
||||||
include_failed = JobStatus.FAILED in status_filter
|
|
||||||
if include_completed or include_failed:
|
|
||||||
for prompt_id, history_item in history.items():
|
|
||||||
is_failed = history_item.get('status', {}).get('status_str') == 'error'
|
|
||||||
if (is_failed and include_failed) or (not is_failed and include_completed):
|
|
||||||
jobs.append(normalize_history_item(prompt_id, history_item))
|
|
||||||
|
|
||||||
if workflow_id:
|
|
||||||
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
|
|
||||||
|
|
||||||
jobs = apply_sorting(jobs, sort_by, sort_order)
|
|
||||||
|
|
||||||
total_count = len(jobs)
|
|
||||||
|
|
||||||
if offset > 0:
|
|
||||||
jobs = jobs[offset:]
|
|
||||||
if limit is not None:
|
|
||||||
jobs = jobs[:limit]
|
|
||||||
|
|
||||||
return (jobs, total_count)
|
|
||||||
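The jobs helpers shown above back the /api/jobs endpoint. A small usage sketch with those functions in scope; the queue tuples and history dict are toy stand-ins shaped the way the docstrings describe (5-element prompt tuples, history keyed by prompt_id):

running = [(0, "prompt-a", {}, {"create_time": 100}, [])]
queued = [(1, "prompt-b", {}, {"create_time": 101}, [])]
history = {
    "prompt-c": {
        "prompt": (2, "prompt-c", {}, {"create_time": 99}, []),
        "status": {"status_str": "success", "messages": []},
        "outputs": {},
    },
}

jobs, total = get_all_jobs(
    running, queued, history,
    status_filter=[JobStatus.PENDING, JobStatus.COMPLETED],
    sort_by="created_at", sort_order="desc",
    limit=10, offset=0,
)
print(total, [j["id"] for j in jobs])  # 2 ['prompt-b', 'prompt-c']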
@ -659,40 +659,6 @@ class SamplerSASolver(io.ComfyNode):
|
|||||||
get_sampler = execute
|
get_sampler = execute
|
||||||
|
|
||||||
|
|
||||||
class SamplerSEEDS2(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return io.Schema(
|
|
||||||
node_id="SamplerSEEDS2",
|
|
||||||
category="sampling/custom_sampling/samplers",
|
|
||||||
inputs=[
|
|
||||||
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
|
||||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
|
|
||||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
|
|
||||||
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
|
|
||||||
],
|
|
||||||
outputs=[io.Sampler.Output()],
|
|
||||||
description=(
|
|
||||||
"This sampler node can represent multiple samplers:\n\n"
|
|
||||||
"seeds_2\n"
|
|
||||||
"- default setting\n\n"
|
|
||||||
"exp_heun_2_x0\n"
|
|
||||||
"- solver_type=phi_2, r=1.0, eta=0.0\n\n"
|
|
||||||
"exp_heun_2_x0_sde\n"
|
|
||||||
"- solver_type=phi_2, r=1.0, eta=1.0, s_noise=1.0"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
|
|
||||||
sampler_name = "seeds_2"
|
|
||||||
sampler = comfy.samplers.ksampler(
|
|
||||||
sampler_name,
|
|
||||||
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
|
|
||||||
)
|
|
||||||
return io.NodeOutput(sampler)
|
|
||||||
|
|
||||||
|
|
||||||
class Noise_EmptyNoise:
|
class Noise_EmptyNoise:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.seed = 0
|
self.seed = 0
|
||||||
@ -1030,7 +996,6 @@ class CustomSamplersExtension(ComfyExtension):
|
|||||||
SamplerDPMAdaptative,
|
SamplerDPMAdaptative,
|
||||||
SamplerER_SDE,
|
SamplerER_SDE,
|
||||||
SamplerSASolver,
|
SamplerSASolver,
|
||||||
SamplerSEEDS2,
|
|
||||||
SplitSigmas,
|
SplitSigmas,
|
||||||
SplitSigmasDenoise,
|
SplitSigmasDenoise,
|
||||||
FlipSigmas,
|
FlipSigmas,
|
||||||
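SamplerSEEDS2 above always constructs the "seeds_2" ksampler and only varies the option dict, so the named variants in its description map to fixed parameter sets. A small sketch of that mapping; the default solver_type is assumed to be the first combo option ("phi_1"), and comfy.samplers needs a working ComfyUI environment to import:

import comfy.samplers

SEEDS2_PRESETS = {
    "seeds_2":           {"solver_type": "phi_1", "eta": 1.0, "s_noise": 1.0, "r": 0.5},
    "exp_heun_2_x0":     {"solver_type": "phi_2", "eta": 0.0, "s_noise": 1.0, "r": 1.0},
    "exp_heun_2_x0_sde": {"solver_type": "phi_2", "eta": 1.0, "s_noise": 1.0, "r": 1.0},
}

def make_seeds2_sampler(preset: str = "seeds_2"):
    # Same call the node makes in execute(); only the options differ per preset.
    return comfy.samplers.ksampler("seeds_2", SEEDS2_PRESETS[preset])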
|
|||||||
@ -1125,99 +1125,6 @@ class MergeTextListsNode(TextProcessingNode):
|
|||||||
# ========== Training Dataset Nodes ==========
|
# ========== Training Dataset Nodes ==========
|
||||||
|
|
||||||
|
|
||||||
class ResolutionBucket(io.ComfyNode):
|
|
||||||
"""Bucket latents and conditions by resolution for efficient batch training."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return io.Schema(
|
|
||||||
node_id="ResolutionBucket",
|
|
||||||
display_name="Resolution Bucket",
|
|
||||||
category="dataset",
|
|
||||||
is_experimental=True,
|
|
||||||
is_input_list=True,
|
|
||||||
inputs=[
|
|
||||||
io.Latent.Input(
|
|
||||||
"latents",
|
|
||||||
tooltip="List of latent dicts to bucket by resolution.",
|
|
||||||
),
|
|
||||||
io.Conditioning.Input(
|
|
||||||
"conditioning",
|
|
||||||
tooltip="List of conditioning lists (must match latents length).",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
io.Latent.Output(
|
|
||||||
display_name="latents",
|
|
||||||
is_output_list=True,
|
|
||||||
tooltip="List of batched latent dicts, one per resolution bucket.",
|
|
||||||
),
|
|
||||||
io.Conditioning.Output(
|
|
||||||
display_name="conditioning",
|
|
||||||
is_output_list=True,
|
|
||||||
tooltip="List of condition lists, one per resolution bucket.",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, latents, conditioning):
|
|
||||||
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
|
|
||||||
# conditioning: list[list[cond]]
|
|
||||||
|
|
||||||
# Validate lengths match
|
|
||||||
if len(latents) != len(conditioning):
|
|
||||||
raise ValueError(
|
|
||||||
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flatten latents and conditions to individual samples
|
|
||||||
flat_latents = [] # list of (C, H, W) tensors
|
|
||||||
flat_conditions = [] # list of condition lists
|
|
||||||
|
|
||||||
for latent_dict, cond in zip(latents, conditioning):
|
|
||||||
samples = latent_dict["samples"] # (B, C, H, W)
|
|
||||||
batch_size = samples.shape[0]
|
|
||||||
|
|
||||||
# cond is a list of conditions with length == batch_size
|
|
||||||
for i in range(batch_size):
|
|
||||||
flat_latents.append(samples[i]) # (C, H, W)
|
|
||||||
flat_conditions.append(cond[i]) # single condition
|
|
||||||
|
|
||||||
# Group by resolution (H, W)
|
|
||||||
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
|
|
||||||
|
|
||||||
for latent, cond in zip(flat_latents, flat_conditions):
|
|
||||||
# latent shape is (..., H, W): (B, C, H, W) or (B, T, C, H, W)
|
|
||||||
h, w = latent.shape[-2], latent.shape[-1]
|
|
||||||
key = (h, w)
|
|
||||||
|
|
||||||
if key not in buckets:
|
|
||||||
buckets[key] = {"latents": [], "conditions": []}
|
|
||||||
|
|
||||||
buckets[key]["latents"].append(latent)
|
|
||||||
buckets[key]["conditions"].append(cond)
|
|
||||||
|
|
||||||
# Convert buckets to output format
|
|
||||||
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
|
|
||||||
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
|
|
||||||
|
|
||||||
for (h, w), bucket_data in buckets.items():
|
|
||||||
# Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
|
|
||||||
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
|
|
||||||
output_latents.append({"samples": stacked_latents})
|
|
||||||
|
|
||||||
# Conditions stay as list of condition lists
|
|
||||||
output_conditions.append(bucket_data["conditions"])
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
|
|
||||||
return io.NodeOutput(output_latents, output_conditions)
|
|
||||||
|
|
||||||
|
|
||||||
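The grouping step ResolutionBucket performs, reduced to a standalone sketch: flatten a list of (B, C, H, W) latents, bucket the individual samples by spatial size, then re-stack each bucket into one batch. Toy tensors only; conditioning handling is omitted:

import torch

latents = [torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64), torch.randn(1, 4, 32, 48)]

buckets: dict[tuple[int, int], list[torch.Tensor]] = {}
for lat in latents:
    for sample in lat:                                   # iterate the batch dimension
        key = (sample.shape[-2], sample.shape[-1])       # (H, W)
        buckets.setdefault(key, []).append(sample)

batched = {key: torch.stack(group, dim=0) for key, group in buckets.items()}
for (h, w), batch in batched.items():
    print(f"{h}x{w}: batch of {batch.shape[0]}")         # 64x64: 2 samples, 32x48: 1 sample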
class MakeTrainingDataset(io.ComfyNode):
|
class MakeTrainingDataset(io.ComfyNode):
|
||||||
"""Encode images with VAE and texts with CLIP to create a training dataset."""
|
"""Encode images with VAE and texts with CLIP to create a training dataset."""
|
||||||
|
|
||||||
@ -1466,7 +1373,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
|||||||
shard_path = os.path.join(dataset_dir, shard_file)
|
shard_path = os.path.join(dataset_dir, shard_file)
|
||||||
|
|
||||||
with open(shard_path, "rb") as f:
|
with open(shard_path, "rb") as f:
|
||||||
shard_data = torch.load(f)
|
shard_data = torch.load(f, weights_only=True)
|
||||||
|
|
||||||
all_latents.extend(shard_data["latents"])
|
all_latents.extend(shard_data["latents"])
|
||||||
all_conditioning.extend(shard_data["conditioning"])
|
all_conditioning.extend(shard_data["conditioning"])
|
||||||
@ -1518,7 +1425,6 @@ class DatasetExtension(ComfyExtension):
|
|||||||
MakeTrainingDataset,
|
MakeTrainingDataset,
|
||||||
SaveTrainingDataset,
|
SaveTrainingDataset,
|
||||||
LoadTrainingDataset,
|
LoadTrainingDataset,
|
||||||
ResolutionBucket,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -154,13 +154,12 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="FluxKontextMultiReferenceLatentMethod",
|
node_id="FluxKontextMultiReferenceLatentMethod",
|
||||||
display_name="Edit Model Reference Method",
|
|
||||||
category="advanced/conditioning/flux",
|
category="advanced/conditioning/flux",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Conditioning.Input("conditioning"),
|
io.Conditioning.Input("conditioning"),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"reference_latents_method",
|
"reference_latents_method",
|
||||||
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
|
options=["offset", "index", "uxo/uno"],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
|
|||||||
@ -243,16 +243,7 @@ class ModelPatchLoader:
|
|||||||
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
|
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
|
||||||
sd = z_image_convert(sd)
|
sd = z_image_convert(sd)
|
||||||
config = {}
|
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
|
|
||||||
config['n_control_layers'] = 15
|
|
||||||
config['additional_in_dim'] = 17
|
|
||||||
config['refiner_control'] = True
|
|
||||||
ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None)
|
|
||||||
if ref_weight is not None:
|
|
||||||
if torch.count_nonzero(ref_weight) == 0:
|
|
||||||
config['broken'] = True
|
|
||||||
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
|
|
||||||
|
|
||||||
model.load_state_dict(sd)
|
model.load_state_dict(sd)
|
||||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
@ -306,122 +297,62 @@ class DiffSynthCnetPatch:
|
|||||||
return [self.model_patch]
|
return [self.model_patch]
|
||||||
|
|
||||||
class ZImageControlPatch:
|
class ZImageControlPatch:
|
||||||
def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
|
def __init__(self, model_patch, vae, image, strength):
|
||||||
self.model_patch = model_patch
|
self.model_patch = model_patch
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
self.image = image
|
self.image = image
|
||||||
self.inpaint_image = inpaint_image
|
|
||||||
self.mask = mask
|
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.is_inpaint = self.model_patch.model.additional_in_dim > 0
|
self.encoded_image = self.encode_latent_cond(image)
|
||||||
|
self.encoded_image_size = (image.shape[1], image.shape[2])
|
||||||
skip_encoding = False
|
|
||||||
if self.image is not None and self.inpaint_image is not None:
|
|
||||||
if self.image.shape != self.inpaint_image.shape:
|
|
||||||
skip_encoding = True
|
|
||||||
|
|
||||||
if skip_encoding:
|
|
||||||
self.encoded_image = None
|
|
||||||
else:
|
|
||||||
self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image)
|
|
||||||
if self.image is None:
|
|
||||||
self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2])
|
|
||||||
else:
|
|
||||||
self.encoded_image_size = (self.image.shape[1], self.image.shape[2])
|
|
||||||
self.temp_data = None
|
self.temp_data = None
|
||||||
|
|
||||||
def encode_latent_cond(self, control_image=None, inpaint_image=None):
|
def encode_latent_cond(self, image):
|
||||||
latent_image = None
|
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
|
||||||
if control_image is not None:
|
return latent_image
|
||||||
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
|
|
||||||
|
|
||||||
if self.is_inpaint:
|
|
||||||
if inpaint_image is None:
|
|
||||||
inpaint_image = torch.ones_like(control_image) * 0.5
|
|
||||||
|
|
||||||
if self.mask is not None:
|
|
||||||
mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
|
|
||||||
inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5
|
|
||||||
|
|
||||||
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
|
|
||||||
|
|
||||||
if self.mask is None:
|
|
||||||
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
|
|
||||||
else:
|
|
||||||
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
|
|
||||||
|
|
||||||
if latent_image is None:
|
|
||||||
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
|
|
||||||
|
|
||||||
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
|
|
||||||
else:
|
|
||||||
return latent_image
|
|
||||||
|
|
||||||
def __call__(self, kwargs):
|
def __call__(self, kwargs):
|
||||||
x = kwargs.get("x")
|
x = kwargs.get("x")
|
||||||
img = kwargs.get("img")
|
img = kwargs.get("img")
|
||||||
img_input = kwargs.get("img_input")
|
|
||||||
txt = kwargs.get("txt")
|
txt = kwargs.get("txt")
|
||||||
pe = kwargs.get("pe")
|
pe = kwargs.get("pe")
|
||||||
vec = kwargs.get("vec")
|
vec = kwargs.get("vec")
|
||||||
block_index = kwargs.get("block_index")
|
block_index = kwargs.get("block_index")
|
||||||
block_type = kwargs.get("block_type", "")
|
|
||||||
spacial_compression = self.vae.spacial_compression_encode()
|
spacial_compression = self.vae.spacial_compression_encode()
|
||||||
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
||||||
image_scaled = None
|
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
||||||
if self.image is not None:
|
|
||||||
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
|
|
||||||
self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2])
|
|
||||||
|
|
||||||
inpaint_scaled = None
|
|
||||||
if self.inpaint_image is not None:
|
|
||||||
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
|
|
||||||
self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2])
|
|
||||||
|
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled)
|
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
|
||||||
|
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
|
||||||
comfy.model_management.load_models_gpu(loaded_models)
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
|
|
||||||
cnet_blocks = self.model_patch.model.n_control_layers
|
cnet_index = (block_index // 5)
|
||||||
div = round(30 / cnet_blocks)
|
cnet_index_float = (block_index / 5)
|
||||||
|
|
||||||
cnet_index = (block_index // div)
|
|
||||||
cnet_index_float = (block_index / div)
|
|
||||||
|
|
||||||
kwargs.pop("img") # we do ops in place
|
kwargs.pop("img") # we do ops in place
|
||||||
kwargs.pop("txt")
|
kwargs.pop("txt")
|
||||||
|
|
||||||
|
cnet_blocks = self.model_patch.model.n_control_layers
|
||||||
if cnet_index_float > (cnet_blocks - 1):
|
if cnet_index_float > (cnet_blocks - 1):
|
||||||
self.temp_data = None
|
self.temp_data = None
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
if self.temp_data is None or self.temp_data[0] > cnet_index:
|
if self.temp_data is None or self.temp_data[0] > cnet_index:
|
||||||
if block_type == "noise_refiner":
|
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
||||||
self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
|
||||||
else:
|
|
||||||
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
|
||||||
|
|
||||||
if block_type == "noise_refiner":
|
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
|
||||||
next_layer = self.temp_data[0] + 1
|
next_layer = self.temp_data[0] + 1
|
||||||
self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
||||||
if self.temp_data[1][0] is not None:
|
|
||||||
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
|
||||||
else:
|
|
||||||
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
|
|
||||||
next_layer = self.temp_data[0] + 1
|
|
||||||
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
|
||||||
|
|
||||||
if cnet_index_float == self.temp_data[0]:
|
if cnet_index_float == self.temp_data[0]:
|
||||||
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
||||||
if cnet_blocks == self.temp_data[0] + 1:
|
if cnet_blocks == self.temp_data[0] + 1:
|
||||||
self.temp_data = None
|
self.temp_data = None
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def to(self, device_or_dtype):
|
def to(self, device_or_dtype):
|
||||||
if isinstance(device_or_dtype, torch.device):
|
if isinstance(device_or_dtype, torch.device):
|
||||||
if self.encoded_image is not None:
|
self.encoded_image = self.encoded_image.to(device_or_dtype)
|
||||||
self.encoded_image = self.encoded_image.to(device_or_dtype)
|
|
||||||
self.temp_data = None
|
self.temp_data = None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -444,12 +375,9 @@ class QwenImageDiffsynthControlnet:
|
|||||||
|
|
||||||
CATEGORY = "advanced/loaders/qwen"
|
CATEGORY = "advanced/loaders/qwen"
|
||||||
|
|
||||||
def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None):
|
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
|
||||||
model_patched = model.clone()
|
model_patched = model.clone()
|
||||||
if image is not None:
|
image = image[:, :, :, :3]
|
||||||
image = image[:, :, :, :3]
|
|
||||||
if inpaint_image is not None:
|
|
||||||
inpaint_image = inpaint_image[:, :, :, :3]
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
if mask.ndim == 3:
|
if mask.ndim == 3:
|
||||||
mask = mask.unsqueeze(1)
|
mask = mask.unsqueeze(1)
|
||||||
@ -458,24 +386,11 @@ class QwenImageDiffsynthControlnet:
|
|||||||
mask = 1.0 - mask
|
mask = 1.0 - mask
|
||||||
|
|
||||||
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
|
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
|
||||||
patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask)
|
model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
|
||||||
model_patched.set_model_noise_refiner_patch(patch)
|
|
||||||
model_patched.set_model_double_block_patch(patch)
|
|
||||||
else:
|
else:
|
||||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||||
return (model_patched,)
|
return (model_patched,)
|
||||||
|
|
||||||
class ZImageFunControlnet(QwenImageDiffsynthControlnet):
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": { "model": ("MODEL",),
|
|
||||||
"model_patch": ("MODEL_PATCH",),
|
|
||||||
"vae": ("VAE",),
|
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
|
||||||
},
|
|
||||||
"optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}}
|
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders/zimage"
|
|
||||||
|
|
||||||
class UsoStyleProjectorPatch:
|
class UsoStyleProjectorPatch:
|
||||||
def __init__(self, model_patch, encoded_image):
|
def __init__(self, model_patch, encoded_image):
|
||||||
@ -523,6 +438,5 @@ class USOStyleReference:
|
|||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ModelPatchLoader": ModelPatchLoader,
|
"ModelPatchLoader": ModelPatchLoader,
|
||||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||||
"ZImageFunControlnet": ZImageFunControlnet,
|
|
||||||
"USOStyleReference": USOStyleReference,
|
"USOStyleReference": USOStyleReference,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -221,7 +221,6 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
|||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
io.Combo.Input("upscale_method", options=cls.upscale_methods),
|
io.Combo.Input("upscale_method", options=cls.upscale_methods),
|
||||||
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
||||||
io.Int.Input("resolution_steps", default=1, min=1, max=256),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Image.Output(),
|
io.Image.Output(),
|
||||||
@ -229,15 +228,15 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
|
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
|
||||||
samples = image.movedim(-1,1)
|
samples = image.movedim(-1,1)
|
||||||
total = megapixels * 1024 * 1024
|
total = int(megapixels * 1024 * 1024)
|
||||||
|
|
||||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||||
width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
|
width = round(samples.shape[3] * scale_by)
|
||||||
height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps
|
height = round(samples.shape[2] * scale_by)
|
||||||
|
|
||||||
s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
|
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
||||||
s = s.movedim(1,-1)
|
s = s.movedim(1,-1)
|
||||||
return io.NodeOutput(s)
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
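The new resolution_steps input above snaps the scaled width and height to the nearest multiple of that step instead of the nearest pixel. The arithmetic as a standalone sketch:

import math

def target_size(width: int, height: int, megapixels: float, resolution_steps: int = 1) -> tuple[int, int]:
    total = megapixels * 1024 * 1024
    scale = math.sqrt(total / (width * height))
    new_w = round(width * scale / resolution_steps) * resolution_steps
    new_h = round(height * scale / resolution_steps) * resolution_steps
    return int(new_w), int(new_h)

print(target_size(1920, 1080, 1.0, resolution_steps=8))   # (1368, 768)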
|
|||||||
@ -2,8 +2,6 @@ from typing_extensions import override
|
|||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
from comfy_api.torch_helpers import set_torch_compile_wrapper
|
from comfy_api.torch_helpers import set_torch_compile_wrapper
|
||||||
|
|
||||||
def skip_torch_compile_dict(guard_entries):
|
|
||||||
return [("transformer_options" not in entry.name) for entry in guard_entries]
|
|
||||||
|
|
||||||
class TorchCompileModel(io.ComfyNode):
|
class TorchCompileModel(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -25,7 +23,7 @@ class TorchCompileModel(io.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model, backend) -> io.NodeOutput:
|
def execute(cls, model, backend) -> io.NodeOutput:
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
|
set_torch_compile_wrapper(model=m, backend=backend)
|
||||||
return io.NodeOutput(m)
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|
||||||
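skip_torch_compile_dict above acts as a guard filter: it receives dynamo guard entries (each exposing a .name) and returns one boolean per entry, False for anything touching transformer_options, which filters those guards out. A toy illustration of that contract with stand-in entry objects rather than real guards:

from dataclasses import dataclass

@dataclass
class FakeGuardEntry:
    name: str

def skip_torch_compile_dict(guard_entries):
    return [("transformer_options" not in entry.name) for entry in guard_entries]

entries = [FakeGuardEntry("L['kwargs']['transformer_options']"), FakeGuardEntry("L['x']")]
print(skip_torch_compile_dict(entries))   # [False, True]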
|
|||||||
@ -10,7 +10,6 @@ from PIL import Image, ImageDraw, ImageFont
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import comfy.sampler_helpers
|
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@@ -22,68 +21,6 @@ from comfy_api.latest import ComfyExtension, io, ui
 from comfy.utils import ProgressBar
 
 
-class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
-    """
-    CFGGuider with modifications for training specific logic
-    """
-    def outer_sample(
-        self,
-        noise,
-        latent_image,
-        sampler,
-        sigmas,
-        denoise_mask=None,
-        callback=None,
-        disable_pbar=False,
-        seed=None,
-        latent_shapes=None,
-    ):
-        self.inner_model, self.conds, self.loaded_models = (
-            comfy.sampler_helpers.prepare_sampling(
-                self.model_patcher,
-                noise.shape,
-                self.conds,
-                self.model_options,
-                force_full_load=True,  # mirror behavior in TrainLoraNode.execute() to keep model loaded
-            )
-        )
-        device = self.model_patcher.load_device
-
-        if denoise_mask is not None:
-            denoise_mask = comfy.sampler_helpers.prepare_mask(
-                denoise_mask, noise.shape, device
-            )
-
-        noise = noise.to(device)
-        latent_image = latent_image.to(device)
-        sigmas = sigmas.to(device)
-        comfy.samplers.cast_to_load_options(
-            self.model_options, device=device, dtype=self.model_patcher.model_dtype()
-        )
-
-        try:
-            self.model_patcher.pre_run()
-            output = self.inner_sample(
-                noise,
-                latent_image,
-                device,
-                sampler,
-                sigmas,
-                denoise_mask,
-                callback,
-                disable_pbar,
-                seed,
-                latent_shapes=latent_shapes,
-            )
-        finally:
-            self.model_patcher.cleanup()
-
-        comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
-        del self.inner_model
-        del self.loaded_models
-        return output
-
-
 def make_batch_extra_option_dict(d, indicies, full_size=None):
     new_dict = {}
     for k, v in d.items():
@@ -128,7 +65,6 @@ class TrainSampler(comfy.samplers.Sampler):
         seed=0,
         training_dtype=torch.bfloat16,
         real_dataset=None,
-        bucket_latents=None,
     ):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
@@ -139,28 +75,6 @@ class TrainSampler(comfy.samplers.Sampler):
         self.seed = seed
         self.training_dtype = training_dtype
         self.real_dataset: list[torch.Tensor] | None = real_dataset
-        # Bucket mode data
-        self.bucket_latents: list[torch.Tensor] | None = (
-            bucket_latents  # list of (Bi, C, Hi, Wi)
-        )
-        # Precompute bucket offsets and weights for sampling
-        if bucket_latents is not None:
-            self._init_bucket_data(bucket_latents)
-        else:
-            self.bucket_offsets = None
-            self.bucket_weights = None
-            self.num_images = None
-
-    def _init_bucket_data(self, bucket_latents):
-        """Initialize bucket offsets and weights for sampling."""
-        self.bucket_offsets = [0]
-        bucket_sizes = []
-        for lat in bucket_latents:
-            bucket_sizes.append(lat.shape[0])
-            self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
-        self.num_images = self.bucket_offsets[-1]
-        # Weights for sampling buckets proportional to their size
-        self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
 
     def fwd_bwd(
         self,
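The bucket bookkeeping in _init_bucket_data reduces to a prefix sum of bucket sizes plus a size-proportional sampling weight. A standalone sketch with made-up bucket sizes (torch.multinomial draws a bucket with probability proportional to its element count):

import torch

bucket_sizes = [12, 4, 8]                      # hypothetical (Bi) counts per resolution bucket
bucket_offsets = [0]
for size in bucket_sizes:
    bucket_offsets.append(bucket_offsets[-1] + size)
num_images = bucket_offsets[-1]                # 24
bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)

# Draw one bucket, weighted by how many latents it holds.
bucket_idx = torch.multinomial(bucket_weights, 1).item()
print(num_images, bucket_offsets, bucket_idx)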
@@ -201,108 +115,6 @@ class TrainSampler(comfy.samplers.Sampler):
             bwd_loss.backward()
         return loss
 
-    def _generate_batch_sigmas(self, model_wrap, batch_size, device):
-        """Generate random sigma values for a batch."""
-        batch_sigmas = [
-            model_wrap.inner_model.model_sampling.percent_to_sigma(
-                torch.rand((1,)).item()
-            )
-            for _ in range(batch_size)
-        ]
-        return torch.tensor(batch_sigmas).to(device)
-
-    def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
-        """Execute one training step in bucket mode."""
-        # Sample bucket (weighted by size), then sample batch from bucket
-        bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
-        bucket_latent = self.bucket_latents[bucket_idx]  # (Bi, C, Hi, Wi)
-        bucket_size = bucket_latent.shape[0]
-        bucket_offset = self.bucket_offsets[bucket_idx]
-
-        # Sample indices from this bucket (use all if bucket_size < batch_size)
-        actual_batch_size = min(self.batch_size, bucket_size)
-        relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
-        # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
-        absolute_indices = [bucket_offset + idx for idx in relative_indices]
-
-        batch_latent = bucket_latent[relative_indices].to(latent_image)  # (actual_batch_size, C, H, W)
-        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
-            batch_latent.device
-        )
-        batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
-
-        loss = self.fwd_bwd(
-            model_wrap,
-            batch_sigmas,
-            batch_noise,
-            batch_latent,
-            cond,  # Use flattened cond with absolute indices
-            absolute_indices,
-            extra_args,
-            self.num_images,
-            bwd=True,
-        )
-        if self.loss_callback:
-            self.loss_callback(loss.item())
-        pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
-
-    def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
-        """Execute one training step in standard (non-bucket, non-multi-res) mode."""
-        indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
-        batch_latent = torch.stack([latent_image[i] for i in indicies])
-        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
-            batch_latent.device
-        )
-        batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)
-
-        loss = self.fwd_bwd(
-            model_wrap,
-            batch_sigmas,
-            batch_noise,
-            batch_latent,
-            cond,
-            indicies,
-            extra_args,
-            dataset_size,
-            bwd=True,
-        )
-        if self.loss_callback:
-            self.loss_callback(loss.item())
-        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
-
-    def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
-        """Execute one training step in multi-resolution mode (real_dataset is set)."""
-        indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
-        total_loss = 0
-        for index in indicies:
-            single_latent = self.real_dataset[index].to(latent_image)
-            batch_noise = noisegen.generate_noise(
-                {"samples": single_latent}
-            ).to(single_latent.device)
-            batch_sigmas = (
-                model_wrap.inner_model.model_sampling.percent_to_sigma(
-                    torch.rand((1,)).item()
-                )
-            )
-            batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
-            loss = self.fwd_bwd(
-                model_wrap,
-                batch_sigmas,
-                batch_noise,
-                single_latent,
-                cond,
-                [index],
-                extra_args,
-                dataset_size,
-                bwd=False,
-            )
-            total_loss += loss
-        total_loss = total_loss / self.grad_acc / len(indicies)
-        total_loss.backward()
-        if self.loss_callback:
-            self.loss_callback(total_loss.item())
-        pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
-
     def sample(
         self,
         model_wrap,
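Because the conditioning list stays flat across all buckets, indices drawn inside one bucket have to be shifted by that bucket's offset before they can index the flat list, which is what the absolute_indices computation above does. A tiny illustration with hypothetical values:

import torch

bucket_offsets = [0, 12, 16, 24]     # hypothetical prefix sums for buckets of 12, 4 and 8 latents
bucket_idx = 1                       # the bucket that was sampled this step
bucket_size = 4
batch_size = 2

relative_indices = torch.randperm(bucket_size)[:batch_size].tolist()            # e.g. [3, 0]
absolute_indices = [bucket_offsets[bucket_idx] + i for i in relative_indices]    # e.g. [15, 12]
print(relative_indices, absolute_indices)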
@@ -330,18 +142,70 @@ class TrainSampler(comfy.samplers.Sampler):
             noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
                 self.seed + i * 1000
             )
+            indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
 
-            if self.bucket_latents is not None:
-                self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
-            elif self.real_dataset is None:
-                self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
+            if self.real_dataset is None:
+                batch_latent = torch.stack([latent_image[i] for i in indicies])
+                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
+                    batch_latent.device
+                )
+                batch_sigmas = [
+                    model_wrap.inner_model.model_sampling.percent_to_sigma(
+                        torch.rand((1,)).item()
+                    )
+                    for _ in range(min(self.batch_size, dataset_size))
+                ]
+                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
+
+                loss = self.fwd_bwd(
+                    model_wrap,
+                    batch_sigmas,
+                    batch_noise,
+                    batch_latent,
+                    cond,
+                    indicies,
+                    extra_args,
+                    dataset_size,
+                    bwd=True,
+                )
+                if self.loss_callback:
+                    self.loss_callback(loss.item())
+                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
             else:
-                self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
+                total_loss = 0
+                for index in indicies:
+                    single_latent = self.real_dataset[index].to(latent_image)
+                    batch_noise = noisegen.generate_noise(
+                        {"samples": single_latent}
+                    ).to(single_latent.device)
+                    batch_sigmas = (
+                        model_wrap.inner_model.model_sampling.percent_to_sigma(
+                            torch.rand((1,)).item()
+                        )
+                    )
+                    batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
+                    loss = self.fwd_bwd(
+                        model_wrap,
+                        batch_sigmas,
+                        batch_noise,
+                        single_latent,
+                        cond,
+                        [index],
+                        extra_args,
+                        dataset_size,
+                        bwd=False,
+                    )
+                    total_loss += loss
+                total_loss = total_loss / self.grad_acc / len(indicies)
+                total_loss.backward()
+                if self.loss_callback:
+                    self.loss_callback(total_loss.item())
+                pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
 
             if (i + 1) % self.grad_acc == 0:
                 self.optimizer.step()
                 self.optimizer.zero_grad()
             ui_pbar.update(1)
         torch.cuda.empty_cache()
         return torch.zeros_like(latent_image)
 
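Both versions of the loop step the optimizer only every grad_acc iterations, the usual gradient-accumulation pattern. A self-contained sketch of that pattern on a toy model (names are illustrative, not taken from the node):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_acc = 4

for i in range(16):
    x = torch.randn(2, 4)
    loss = model(x).pow(2).mean() / grad_acc   # scale so the accumulated gradients average out
    loss.backward()                            # gradients accumulate across iterations
    if (i + 1) % grad_acc == 0:
        optimizer.step()
        optimizer.zero_grad()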
@ -419,364 +283,6 @@ def unpatch(m):
|
|||||||
del m.org_forward
|
del m.org_forward
|
||||||
|
|
||||||
|
|
||||||
def _process_latents_bucket_mode(latents):
|
|
||||||
"""Process latents for bucket mode training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list of latent tensors
|
|
||||||
"""
|
|
||||||
bucket_latents = []
|
|
||||||
for latent_dict in latents:
|
|
||||||
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
|
|
||||||
return bucket_latents
|
|
||||||
|
|
||||||
|
|
||||||
def _process_latents_standard_mode(latents):
|
|
||||||
"""Process latents for standard (non-bucket) mode training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
latents: list of latent dicts or single latent dict
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Processed latents (tensor or list of tensors)
|
|
||||||
"""
|
|
||||||
if len(latents) == 1:
|
|
||||||
return latents[0]["samples"] # Single latent dict
|
|
||||||
|
|
||||||
latent_list = []
|
|
||||||
for latent in latents:
|
|
||||||
latent = latent["samples"]
|
|
||||||
bs = latent.shape[0]
|
|
||||||
if bs != 1:
|
|
||||||
for sub_latent in latent:
|
|
||||||
latent_list.append(sub_latent[None])
|
|
||||||
else:
|
|
||||||
latent_list.append(latent)
|
|
||||||
return latent_list
|
|
||||||
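_process_latents_standard_mode normalizes every incoming latent dict to a list of single-image tensors so that images of different sizes can coexist downstream. The same splitting logic in isolation, with fabricated shapes:

import torch

latents = [{"samples": torch.zeros(2, 4, 64, 64)}, {"samples": torch.zeros(1, 4, 32, 48)}]

latent_list = []
for latent in latents:
    t = latent["samples"]
    if t.shape[0] != 1:
        # Split a batched tensor into one (1, C, H, W) tensor per image.
        latent_list.extend(sub[None] for sub in t)
    else:
        latent_list.append(t)

print([tuple(t.shape) for t in latent_list])  # [(1, 4, 64, 64), (1, 4, 64, 64), (1, 4, 32, 48)]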
|
|
||||||
|
|
||||||
def _process_conditioning(positive):
|
|
||||||
"""Process conditioning - either single list or list of lists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
positive: list of conditioning
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Flattened conditioning list
|
|
||||||
"""
|
|
||||||
if len(positive) == 1:
|
|
||||||
return positive[0] # Single conditioning list
|
|
||||||
|
|
||||||
# Multiple conditioning lists - flatten
|
|
||||||
flat_positive = []
|
|
||||||
for cond in positive:
|
|
||||||
if isinstance(cond, list):
|
|
||||||
flat_positive.extend(cond)
|
|
||||||
else:
|
|
||||||
flat_positive.append(cond)
|
|
||||||
return flat_positive
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_latents_and_count(latents, dtype, bucket_mode):
|
|
||||||
"""Convert latents to dtype and compute image counts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
latents: Latents (tensor, list of tensors, or bucket list)
|
|
||||||
dtype: Target dtype
|
|
||||||
bucket_mode: Whether bucket mode is enabled
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (processed_latents, num_images, multi_res)
|
|
||||||
"""
|
|
||||||
if bucket_mode:
|
|
||||||
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
|
|
||||||
latents = [t.to(dtype) for t in latents]
|
|
||||||
num_buckets = len(latents)
|
|
||||||
num_images = sum(t.shape[0] for t in latents)
|
|
||||||
multi_res = False # Not using multi_res path in bucket mode
|
|
||||||
|
|
||||||
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
|
||||||
for i, lat in enumerate(latents):
|
|
||||||
logging.info(f" Bucket {i}: shape {lat.shape}")
|
|
||||||
return latents, num_images, multi_res
|
|
||||||
|
|
||||||
# Non-bucket mode
|
|
||||||
if isinstance(latents, list):
|
|
||||||
all_shapes = set()
|
|
||||||
latents = [t.to(dtype) for t in latents]
|
|
||||||
for latent in latents:
|
|
||||||
all_shapes.add(latent.shape)
|
|
||||||
logging.info(f"Latent shapes: {all_shapes}")
|
|
||||||
if len(all_shapes) > 1:
|
|
||||||
multi_res = True
|
|
||||||
else:
|
|
||||||
multi_res = False
|
|
||||||
latents = torch.cat(latents, dim=0)
|
|
||||||
num_images = len(latents)
|
|
||||||
elif isinstance(latents, torch.Tensor):
|
|
||||||
latents = latents.to(dtype)
|
|
||||||
num_images = latents.shape[0]
|
|
||||||
multi_res = False
|
|
||||||
else:
|
|
||||||
logging.error(f"Invalid latents type: {type(latents)}")
|
|
||||||
num_images = 0
|
|
||||||
multi_res = False
|
|
||||||
|
|
||||||
return latents, num_images, multi_res
|
|
||||||
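_prepare_latents_and_count chooses between the single-batch and multi-resolution paths simply by collecting the set of latent shapes. A small sketch of that decision with made-up tensors:

import torch

latents = [torch.zeros(1, 4, 64, 64), torch.zeros(1, 4, 32, 48)]
all_shapes = {t.shape for t in latents}

if len(all_shapes) > 1:
    multi_res = True                      # keep the list; mixed shapes cannot be concatenated
    num_images = len(latents)
else:
    multi_res = False
    latents = torch.cat(latents, dim=0)   # one uniform batch tensor
    num_images = latents.shape[0]

print(multi_res, num_images)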
|
|
||||||
|
|
||||||
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
|
|
||||||
"""Validate conditioning count matches image count, expand if needed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
positive: Conditioning list
|
|
||||||
num_images: Number of images
|
|
||||||
bucket_mode: Whether bucket mode is enabled
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Validated/expanded conditioning list
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If conditioning count doesn't match image count
|
|
||||||
"""
|
|
||||||
if bucket_mode:
|
|
||||||
return positive # Skip validation in bucket mode
|
|
||||||
|
|
||||||
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
|
||||||
if len(positive) == 1 and num_images > 1:
|
|
||||||
return positive * num_images
|
|
||||||
elif len(positive) != num_images:
|
|
||||||
raise ValueError(
|
|
||||||
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
|
||||||
)
|
|
||||||
return positive
|
|
||||||
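_validate_and_expand_conditioning either broadcasts a single caption to every image or requires a one-to-one match. The rule restated in isolation:

def expand_captions(positive, num_images):
    # One caption for many images: repeat it. Otherwise counts must match exactly.
    if len(positive) == 1 and num_images > 1:
        return positive * num_images
    if len(positive) != num_images:
        raise ValueError(
            f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
        )
    return positive

print(len(expand_captions(["a photo of a dog"], 5)))  # 5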
|
|
||||||
|
|
||||||
def _load_existing_lora(existing_lora):
|
|
||||||
"""Load existing LoRA weights if provided.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
existing_lora: LoRA filename or "[None]"
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (existing_weights dict, existing_steps int)
|
|
||||||
"""
|
|
||||||
if existing_lora == "[None]":
|
|
||||||
return {}, 0
|
|
||||||
|
|
||||||
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
|
||||||
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
|
||||||
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
|
||||||
existing_weights = {}
|
|
||||||
if lora_path:
|
|
||||||
existing_weights = comfy.utils.load_torch_file(lora_path)
|
|
||||||
return existing_weights, existing_steps
|
|
||||||
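The existing step count is recovered purely from the filename convention used when a LoRA is saved. A quick check of that parsing with an example name of the same shape as the comment above (the extension is assumed):

existing_lora = "trained_lora_10_steps_20250225_203716.safetensors"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
print(existing_steps)  # 10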
|
|
||||||
|
|
||||||
def _create_weight_adapter(
|
|
||||||
module, module_name, existing_weights, algorithm, lora_dtype, rank
|
|
||||||
):
|
|
||||||
"""Create a weight adapter for a module with weight.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module: The module to create adapter for
|
|
||||||
module_name: Name of the module
|
|
||||||
existing_weights: Dict of existing LoRA weights
|
|
||||||
algorithm: Algorithm name for new adapters
|
|
||||||
lora_dtype: dtype for LoRA weights
|
|
||||||
rank: Rank for new LoRA adapters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (train_adapter, lora_params dict)
|
|
||||||
"""
|
|
||||||
key = f"{module_name}.weight"
|
|
||||||
shape = module.weight.shape
|
|
||||||
lora_params = {}
|
|
||||||
|
|
||||||
if len(shape) >= 2:
|
|
||||||
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
|
||||||
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
|
|
||||||
|
|
||||||
# Try to load existing adapter
|
|
||||||
existing_adapter = None
|
|
||||||
for adapter_cls in adapters:
|
|
||||||
existing_adapter = adapter_cls.load(
|
|
||||||
module_name, existing_weights, alpha, dora_scale
|
|
||||||
)
|
|
||||||
if existing_adapter is not None:
|
|
||||||
break
|
|
||||||
|
|
||||||
if existing_adapter is None:
|
|
||||||
adapter_cls = adapter_maps[algorithm]
|
|
||||||
|
|
||||||
if existing_adapter is not None:
|
|
||||||
train_adapter = existing_adapter.to_train().to(lora_dtype)
|
|
||||||
else:
|
|
||||||
# Use LoRA with alpha=1.0 by default
|
|
||||||
train_adapter = adapter_cls.create_train(
|
|
||||||
module.weight, rank=rank, alpha=1.0
|
|
||||||
).to(lora_dtype)
|
|
||||||
|
|
||||||
for name, parameter in train_adapter.named_parameters():
|
|
||||||
lora_params[f"{module_name}.{name}"] = parameter
|
|
||||||
|
|
||||||
return train_adapter.train().requires_grad_(True), lora_params
|
|
||||||
else:
|
|
||||||
# 1D weight - use BiasDiff
|
|
||||||
diff = torch.nn.Parameter(
|
|
||||||
torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
|
|
||||||
)
|
|
||||||
diff_module = BiasDiff(diff).train().requires_grad_(True)
|
|
||||||
lora_params[f"{module_name}.diff"] = diff
|
|
||||||
return diff_module, lora_params
|
|
||||||
|
|
||||||
|
|
||||||
def _create_bias_adapter(module, module_name, lora_dtype):
|
|
||||||
"""Create a bias adapter for a module with bias.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module: The module with bias
|
|
||||||
module_name: Name of the module
|
|
||||||
lora_dtype: dtype for LoRA weights
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (bias_module, lora_params dict)
|
|
||||||
"""
|
|
||||||
bias = torch.nn.Parameter(
|
|
||||||
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
|
|
||||||
)
|
|
||||||
bias_module = BiasDiff(bias).train().requires_grad_(True)
|
|
||||||
lora_params = {f"{module_name}.diff_b": bias}
|
|
||||||
return bias_module, lora_params
|
|
||||||
|
|
||||||
|
|
||||||
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
|
|
||||||
"""Setup all LoRA adapters on the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mp: Model patcher
|
|
||||||
existing_weights: Dict of existing LoRA weights
|
|
||||||
algorithm: Algorithm name for new adapters
|
|
||||||
lora_dtype: dtype for LoRA weights
|
|
||||||
rank: Rank for new LoRA adapters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (lora_sd dict, all_weight_adapters list)
|
|
||||||
"""
|
|
||||||
lora_sd = {}
|
|
||||||
all_weight_adapters = []
|
|
||||||
|
|
||||||
for n, m in mp.model.named_modules():
|
|
||||||
if hasattr(m, "weight_function"):
|
|
||||||
if m.weight is not None:
|
|
||||||
adapter, params = _create_weight_adapter(
|
|
||||||
m, n, existing_weights, algorithm, lora_dtype, rank
|
|
||||||
)
|
|
||||||
lora_sd.update(params)
|
|
||||||
key = f"{n}.weight"
|
|
||||||
mp.add_weight_wrapper(key, adapter)
|
|
||||||
all_weight_adapters.append(adapter)
|
|
||||||
|
|
||||||
if hasattr(m, "bias") and m.bias is not None:
|
|
||||||
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
|
|
||||||
lora_sd.update(bias_params)
|
|
||||||
key = f"{n}.bias"
|
|
||||||
mp.add_weight_wrapper(key, bias_adapter)
|
|
||||||
all_weight_adapters.append(bias_adapter)
|
|
||||||
|
|
||||||
return lora_sd, all_weight_adapters
|
|
||||||
|
|
||||||
|
|
||||||
def _create_optimizer(optimizer_name, parameters, learning_rate):
|
|
||||||
"""Create optimizer based on name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
|
|
||||||
parameters: Parameters to optimize
|
|
||||||
learning_rate: Learning rate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optimizer instance
|
|
||||||
"""
|
|
||||||
if optimizer_name == "Adam":
|
|
||||||
return torch.optim.Adam(parameters, lr=learning_rate)
|
|
||||||
elif optimizer_name == "AdamW":
|
|
||||||
return torch.optim.AdamW(parameters, lr=learning_rate)
|
|
||||||
elif optimizer_name == "SGD":
|
|
||||||
return torch.optim.SGD(parameters, lr=learning_rate)
|
|
||||||
elif optimizer_name == "RMSprop":
|
|
||||||
return torch.optim.RMSprop(parameters, lr=learning_rate)
|
|
||||||
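The elif chain in _create_optimizer can equally be expressed as a dictionary lookup; the sketch below is an alternative phrasing, not the node's actual implementation:

import torch

_OPTIMIZERS = {
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SGD": torch.optim.SGD,
    "RMSprop": torch.optim.RMSprop,
}

def create_optimizer(name, parameters, learning_rate):
    return _OPTIMIZERS[name](parameters, lr=learning_rate)

opt = create_optimizer("AdamW", torch.nn.Linear(2, 2).parameters(), 1e-4)
print(type(opt).__name__)  # AdamW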
|
|
||||||
|
|
||||||
def _create_loss_function(loss_function_name):
|
|
||||||
"""Create loss function based on name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Loss function instance
|
|
||||||
"""
|
|
||||||
if loss_function_name == "MSE":
|
|
||||||
return torch.nn.MSELoss()
|
|
||||||
elif loss_function_name == "L1":
|
|
||||||
return torch.nn.L1Loss()
|
|
||||||
elif loss_function_name == "Huber":
|
|
||||||
return torch.nn.HuberLoss()
|
|
||||||
elif loss_function_name == "SmoothL1":
|
|
||||||
return torch.nn.SmoothL1Loss()
|
|
||||||
|
|
||||||
|
|
||||||
def _run_training_loop(
|
|
||||||
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
|
|
||||||
):
|
|
||||||
"""Execute the training loop.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
guider: The guider object
|
|
||||||
train_sampler: The training sampler
|
|
||||||
latents: Latent tensors
|
|
||||||
num_images: Number of images
|
|
||||||
seed: Random seed
|
|
||||||
bucket_mode: Whether bucket mode is enabled
|
|
||||||
multi_res: Whether multi-resolution mode is enabled
|
|
||||||
"""
|
|
||||||
sigmas = torch.tensor(range(num_images))
|
|
||||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
|
||||||
|
|
||||||
if bucket_mode:
|
|
||||||
# Use first bucket's first latent as dummy for guider
|
|
||||||
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
|
|
||||||
guider.sample(
|
|
||||||
noise.generate_noise({"samples": dummy_latent}),
|
|
||||||
dummy_latent,
|
|
||||||
train_sampler,
|
|
||||||
sigmas,
|
|
||||||
seed=noise.seed,
|
|
||||||
)
|
|
||||||
elif multi_res:
|
|
||||||
# use first latent as dummy latent if multi_res
|
|
||||||
latents = latents[0].repeat(num_images, 1, 1, 1)
|
|
||||||
guider.sample(
|
|
||||||
noise.generate_noise({"samples": latents}),
|
|
||||||
latents,
|
|
||||||
train_sampler,
|
|
||||||
sigmas,
|
|
||||||
seed=noise.seed,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
guider.sample(
|
|
||||||
noise.generate_noise({"samples": latents}),
|
|
||||||
latents,
|
|
||||||
train_sampler,
|
|
||||||
sigmas,
|
|
||||||
seed=noise.seed,
|
|
||||||
)
|
|
||||||
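In bucket and multi-resolution modes the guider still expects a latent tensor with batch size num_images, so the first latent is repeated purely as a placeholder while the real per-step batches are drawn inside TrainSampler. The repeat itself, with hypothetical shapes:

import torch

first_latent = torch.zeros(1, 4, 32, 48)   # first bucket's first latent, shape (1, C, H, W)
num_images = 24
dummy_latent = first_latent[:1].repeat(num_images, 1, 1, 1)
print(tuple(dummy_latent.shape))           # (24, 4, 32, 48)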
|
|
||||||
|
|
||||||
class TrainLoraNode(io.ComfyNode):
|
class TrainLoraNode(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@@ -879,11 +385,6 @@ class TrainLoraNode(io.ComfyNode):
                     default="[None]",
                     tooltip="The existing LoRA to append to. Set to None for new LoRA.",
                 ),
-                io.Boolean.Input(
-                    "bucket_mode",
-                    default=False,
-                    tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
-                ),
             ],
             outputs=[
                 io.Model.Output(
@@ -918,7 +419,6 @@ class TrainLoraNode(io.ComfyNode):
         algorithm,
         gradient_checkpointing,
         existing_lora,
-        bucket_mode,
     ):
         # Extract scalars from lists (due to is_input_list=True)
         model = model[0]
@ -927,125 +427,215 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
grad_accumulation_steps = grad_accumulation_steps[0]
|
grad_accumulation_steps = grad_accumulation_steps[0]
|
||||||
learning_rate = learning_rate[0]
|
learning_rate = learning_rate[0]
|
||||||
rank = rank[0]
|
rank = rank[0]
|
||||||
optimizer_name = optimizer[0]
|
optimizer = optimizer[0]
|
||||||
loss_function_name = loss_function[0]
|
loss_function = loss_function[0]
|
||||||
seed = seed[0]
|
seed = seed[0]
|
||||||
training_dtype = training_dtype[0]
|
training_dtype = training_dtype[0]
|
||||||
lora_dtype = lora_dtype[0]
|
lora_dtype = lora_dtype[0]
|
||||||
algorithm = algorithm[0]
|
algorithm = algorithm[0]
|
||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
existing_lora = existing_lora[0]
|
existing_lora = existing_lora[0]
|
||||||
bucket_mode = bucket_mode[0]
|
|
||||||
|
|
||||||
# Process latents based on mode
|
# Handle latents - either single dict or list of dicts
|
||||||
if bucket_mode:
|
if len(latents) == 1:
|
||||||
latents = _process_latents_bucket_mode(latents)
|
latents = latents[0]["samples"] # Single latent dict
|
||||||
else:
|
else:
|
||||||
latents = _process_latents_standard_mode(latents)
|
latent_list = []
|
||||||
|
for latent in latents:
|
||||||
|
latent = latent["samples"]
|
||||||
|
bs = latent.shape[0]
|
||||||
|
if bs != 1:
|
||||||
|
for sub_latent in latent:
|
||||||
|
latent_list.append(sub_latent[None])
|
||||||
|
else:
|
||||||
|
latent_list.append(latent)
|
||||||
|
latents = latent_list
|
||||||
|
|
||||||
# Process conditioning
|
# Handle conditioning - either single list or list of lists
|
||||||
positive = _process_conditioning(positive)
|
if len(positive) == 1:
|
||||||
|
positive = positive[0] # Single conditioning list
|
||||||
|
else:
|
||||||
|
# Multiple conditioning lists - flatten
|
||||||
|
flat_positive = []
|
||||||
|
for cond in positive:
|
||||||
|
if isinstance(cond, list):
|
||||||
|
flat_positive.extend(cond)
|
||||||
|
else:
|
||||||
|
flat_positive.append(cond)
|
||||||
|
positive = flat_positive
|
||||||
|
|
||||||
# Setup model and dtype
|
|
||||||
mp = model.clone()
|
mp = model.clone()
|
||||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
mp.set_model_compute_dtype(dtype)
|
mp.set_model_compute_dtype(dtype)
|
||||||
|
|
||||||
# Prepare latents and compute counts
|
# latents here can be list of different size latent or one large batch
|
||||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
if isinstance(latents, list):
|
||||||
latents, dtype, bucket_mode
|
all_shapes = set()
|
||||||
)
|
latents = [t.to(dtype) for t in latents]
|
||||||
|
for latent in latents:
|
||||||
|
all_shapes.add(latent.shape)
|
||||||
|
logging.info(f"Latent shapes: {all_shapes}")
|
||||||
|
if len(all_shapes) > 1:
|
||||||
|
multi_res = True
|
||||||
|
else:
|
||||||
|
multi_res = False
|
||||||
|
latents = torch.cat(latents, dim=0)
|
||||||
|
num_images = len(latents)
|
||||||
|
elif isinstance(latents, torch.Tensor):
|
||||||
|
latents = latents.to(dtype)
|
||||||
|
num_images = latents.shape[0]
|
||||||
|
else:
|
||||||
|
logging.error(f"Invalid latents type: {type(latents)}")
|
||||||
|
|
||||||
# Validate and expand conditioning
|
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
if len(positive) == 1 and num_images > 1:
|
||||||
|
positive = positive * num_images
|
||||||
|
elif len(positive) != num_images:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
||||||
|
)
|
||||||
|
|
||||||
with torch.inference_mode(False):
|
with torch.inference_mode(False):
|
||||||
# Setup models for training
|
lora_sd = {}
|
||||||
mp.model.requires_grad_(False)
|
generator = torch.Generator()
|
||||||
|
generator.manual_seed(seed)
|
||||||
|
|
||||||
# Load existing LoRA weights if provided
|
# Load existing LoRA weights if provided
|
||||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
existing_weights = {}
|
||||||
|
existing_steps = 0
|
||||||
|
if existing_lora != "[None]":
|
||||||
|
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
||||||
|
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
||||||
|
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
||||||
|
if lora_path:
|
||||||
|
existing_weights = comfy.utils.load_torch_file(lora_path)
|
||||||
|
|
||||||
# Setup LoRA adapters
|
all_weight_adapters = []
|
||||||
lora_sd, all_weight_adapters = _setup_lora_adapters(
|
for n, m in mp.model.named_modules():
|
||||||
mp, existing_weights, algorithm, lora_dtype, rank
|
if hasattr(m, "weight_function"):
|
||||||
)
|
if m.weight is not None:
|
||||||
|
key = "{}.weight".format(n)
|
||||||
|
shape = m.weight.shape
|
||||||
|
if len(shape) >= 2:
|
||||||
|
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||||
|
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
|
||||||
|
for adapter_cls in adapters:
|
||||||
|
existing_adapter = adapter_cls.load(
|
||||||
|
n, existing_weights, alpha, dora_scale
|
||||||
|
)
|
||||||
|
if existing_adapter is not None:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
existing_adapter = None
|
||||||
|
adapter_cls = adapter_maps[algorithm]
|
||||||
|
|
||||||
# Create optimizer and loss function
|
if existing_adapter is not None:
|
||||||
optimizer = _create_optimizer(
|
train_adapter = existing_adapter.to_train().to(
|
||||||
optimizer_name, lora_sd.values(), learning_rate
|
lora_dtype
|
||||||
)
|
)
|
||||||
criterion = _create_loss_function(loss_function_name)
|
else:
|
||||||
|
# Use LoRA with alpha=1.0 by default
|
||||||
|
train_adapter = adapter_cls.create_train(
|
||||||
|
m.weight, rank=rank, alpha=1.0
|
||||||
|
).to(lora_dtype)
|
||||||
|
for name, parameter in train_adapter.named_parameters():
|
||||||
|
lora_sd[f"{n}.{name}"] = parameter
|
||||||
|
|
||||||
# Setup gradient checkpointing
|
mp.add_weight_wrapper(key, train_adapter)
|
||||||
|
all_weight_adapters.append(train_adapter)
|
||||||
|
else:
|
||||||
|
diff = torch.nn.Parameter(
|
||||||
|
torch.zeros(
|
||||||
|
m.weight.shape, dtype=lora_dtype, requires_grad=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
diff_module = BiasDiff(diff)
|
||||||
|
mp.add_weight_wrapper(key, BiasDiff(diff))
|
||||||
|
all_weight_adapters.append(diff_module)
|
||||||
|
lora_sd["{}.diff".format(n)] = diff
|
||||||
|
if hasattr(m, "bias") and m.bias is not None:
|
||||||
|
key = "{}.bias".format(n)
|
||||||
|
bias = torch.nn.Parameter(
|
||||||
|
torch.zeros(
|
||||||
|
m.bias.shape, dtype=lora_dtype, requires_grad=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
bias_module = BiasDiff(bias)
|
||||||
|
lora_sd["{}.diff_b".format(n)] = bias
|
||||||
|
mp.add_weight_wrapper(key, BiasDiff(bias))
|
||||||
|
all_weight_adapters.append(bias_module)
|
||||||
|
|
||||||
|
if optimizer == "Adam":
|
||||||
|
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
|
||||||
|
elif optimizer == "AdamW":
|
||||||
|
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
|
||||||
|
elif optimizer == "SGD":
|
||||||
|
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
|
||||||
|
elif optimizer == "RMSprop":
|
||||||
|
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
|
||||||
|
|
||||||
|
# Setup loss function based on selection
|
||||||
|
if loss_function == "MSE":
|
||||||
|
criterion = torch.nn.MSELoss()
|
||||||
|
elif loss_function == "L1":
|
||||||
|
criterion = torch.nn.L1Loss()
|
||||||
|
elif loss_function == "Huber":
|
||||||
|
criterion = torch.nn.HuberLoss()
|
||||||
|
elif loss_function == "SmoothL1":
|
||||||
|
criterion = torch.nn.SmoothL1Loss()
|
||||||
|
|
||||||
|
# setup models
|
||||||
if gradient_checkpointing:
|
if gradient_checkpointing:
|
||||||
for m in find_all_highest_child_module_with_forward(
|
for m in find_all_highest_child_module_with_forward(
|
||||||
mp.model.diffusion_model
|
mp.model.diffusion_model
|
||||||
):
|
):
|
||||||
patch(m)
|
patch(m)
|
||||||
|
mp.model.requires_grad_(False)
|
||||||
torch.cuda.empty_cache()
|
|
||||||
# With force_full_load=False we should be able to have offloading
|
|
||||||
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
|
||||||
comfy.model_management.load_models_gpu(
|
comfy.model_management.load_models_gpu(
|
||||||
[mp], memory_required=1e20, force_full_load=True
|
[mp], memory_required=1e20, force_full_load=True
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Setup loss tracking
|
# Setup sampler and guider like in test script
|
||||||
loss_map = {"loss": []}
|
loss_map = {"loss": []}
|
||||||
|
|
||||||
def loss_callback(loss):
|
def loss_callback(loss):
|
||||||
loss_map["loss"].append(loss)
|
loss_map["loss"].append(loss)
|
||||||
|
|
||||||
# Create sampler
|
train_sampler = TrainSampler(
|
||||||
if bucket_mode:
|
criterion,
|
||||||
train_sampler = TrainSampler(
|
optimizer,
|
||||||
criterion,
|
loss_callback=loss_callback,
|
||||||
optimizer,
|
batch_size=batch_size,
|
||||||
loss_callback=loss_callback,
|
grad_acc=grad_accumulation_steps,
|
||||||
batch_size=batch_size,
|
total_steps=steps * grad_accumulation_steps,
|
||||||
grad_acc=grad_accumulation_steps,
|
seed=seed,
|
||||||
total_steps=steps * grad_accumulation_steps,
|
training_dtype=dtype,
|
||||||
seed=seed,
|
real_dataset=latents if multi_res else None,
|
||||||
training_dtype=dtype,
|
)
|
||||||
bucket_latents=latents,
|
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||||
)
|
guider.set_conds(positive) # Set conditioning from input
|
||||||
else:
|
|
||||||
train_sampler = TrainSampler(
|
|
||||||
criterion,
|
|
||||||
optimizer,
|
|
||||||
loss_callback=loss_callback,
|
|
||||||
batch_size=batch_size,
|
|
||||||
grad_acc=grad_accumulation_steps,
|
|
||||||
total_steps=steps * grad_accumulation_steps,
|
|
||||||
seed=seed,
|
|
||||||
training_dtype=dtype,
|
|
||||||
real_dataset=latents if multi_res else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup guider
|
# Training loop
|
||||||
guider = TrainGuider(mp)
|
|
||||||
guider.set_conds(positive)
|
|
||||||
|
|
||||||
# Run training loop
|
|
||||||
try:
|
try:
|
||||||
_run_training_loop(
|
# Generate dummy sigmas and noise
|
||||||
guider,
|
sigmas = torch.tensor(range(num_images))
|
||||||
train_sampler,
|
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||||
|
if multi_res:
|
||||||
|
# use first latent as dummy latent if multi_res
|
||||||
|
latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
|
||||||
|
guider.sample(
|
||||||
|
noise.generate_noise({"samples": latents}),
|
||||||
latents,
|
latents,
|
||||||
num_images,
|
train_sampler,
|
||||||
seed,
|
sigmas,
|
||||||
bucket_mode,
|
seed=noise.seed,
|
||||||
multi_res,
|
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
for m in mp.model.modules():
|
for m in mp.model.modules():
|
||||||
unpatch(m)
|
unpatch(m)
|
||||||
del train_sampler, optimizer
|
del train_sampler, optimizer
|
||||||
|
|
||||||
# Finalize adapters
|
|
||||||
for adapter in all_weight_adapters:
|
for adapter in all_weight_adapters:
|
||||||
adapter.requires_grad_(False)
|
adapter.requires_grad_(False)
|
||||||
|
|
||||||
@@ -1055,7 +645,7 @@ class TrainLoraNode(io.ComfyNode):
         return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
 
 
-class LoraModelLoader(io.ComfyNode):#
+class LoraModelLoader(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
|
|||||||
@ -1,535 +0,0 @@
|
|||||||
import nodes
|
|
||||||
import node_helpers
|
|
||||||
import torch
|
|
||||||
import torchvision.transforms.functional as TF
|
|
||||||
import comfy.model_management
|
|
||||||
import comfy.utils
|
|
||||||
import numpy as np
|
|
||||||
from typing_extensions import override
|
|
||||||
from comfy_api.latest import ComfyExtension, io
|
|
||||||
from comfy_extras.nodes_wan import parse_json_tracks
|
|
||||||
|
|
||||||
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
|
|
||||||
from PIL import Image, ImageDraw
|
|
||||||
|
|
||||||
SKIP_ZERO = False
|
|
||||||
|
|
||||||
def get_pos_emb(
|
|
||||||
pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
|
|
||||||
pos_emb_dim: int,
|
|
||||||
theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), #Function to compute thetas based on position and embedding dimensions.
|
|
||||||
device: torch.device = torch.device("cpu"),
|
|
||||||
dtype: torch.dtype = torch.float32,
|
|
||||||
) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim)
|
|
||||||
|
|
||||||
assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."
|
|
||||||
pos_k = pos_k.to(device, dtype)
|
|
||||||
if SKIP_ZERO:
|
|
||||||
pos_k = pos_k + 1
|
|
||||||
batch_size = pos_k.size(0)
|
|
||||||
|
|
||||||
denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype)
|
|
||||||
# Expand denominator to match the shape needed for broadcasting
|
|
||||||
denominator_expanded = denominator.view(1, -1).expand(batch_size, -1)
|
|
||||||
|
|
||||||
thetas = theta_func(denominator_expanded, pos_emb_dim)
|
|
||||||
|
|
||||||
# Ensure pos_k is in the correct shape for broadcasting
|
|
||||||
pos_k_expanded = pos_k.view(-1, 1).to(dtype)
|
|
||||||
sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas))
|
|
||||||
cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas))
|
|
||||||
|
|
||||||
# Concatenate sine and cosine embeddings along the last dimension
|
|
||||||
pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1)
|
|
||||||
|
|
||||||
return pos_emb
|
|
||||||
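get_pos_emb builds a standard sinusoidal embedding: each position is divided by a geometric progression of wavelengths, 10000^(2i/d), and the sine and cosine halves are concatenated. A compact restatement for a quick shape check:

import torch

def sinusoidal_pos_emb(pos, dim):
    # pos: (B,) positions; dim must be even. Mirrors the sin/cos-concatenation layout.
    i = torch.arange(dim // 2, dtype=torch.float32)
    thetas = torch.pow(10000, 2 * i / dim)           # per-channel wavelengths
    angles = pos.float().unsqueeze(1) / thetas       # (B, dim//2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

emb = sinusoidal_pos_emb(torch.tensor([0, 1, 7]), 8)
print(emb.shape)  # torch.Size([3, 8])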
|
|
||||||
def create_pos_embeddings(
|
|
||||||
pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
|
|
||||||
pred_visibility: torch.Tensor, # the predicted visibility [T, N]
|
|
||||||
downsample_ratios: list[int], # the ratios for downsampling time, height, and width
|
|
||||||
height: int, # the height of the feature map
|
|
||||||
width: int, # the width of the feature map
|
|
||||||
track_num: int = -1, # the number of tracks to use
|
|
||||||
t_down_strategy: str = "sample", # the strategy for downsampling time dimension
|
|
||||||
):
|
|
||||||
assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."
|
|
||||||
|
|
||||||
t, n, _ = pred_tracks.shape
|
|
||||||
t_down, h_down, w_down = downsample_ratios
|
|
||||||
track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long)
|
|
||||||
|
|
||||||
if track_num == -1:
|
|
||||||
track_num = n
|
|
||||||
|
|
||||||
tracks_idx = torch.randperm(n)[:track_num]
|
|
||||||
tracks = pred_tracks[:, tracks_idx]
|
|
||||||
visibility = pred_visibility[:, tracks_idx]
|
|
||||||
|
|
||||||
for t_idx in range(0, t, t_down):
|
|
||||||
if t_down_strategy == "sample" or t_idx == 0:
|
|
||||||
cur_tracks = tracks[t_idx] # [N, 2]
|
|
||||||
cur_visibility = visibility[t_idx] # [N]
|
|
||||||
else:
|
|
||||||
cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0)
|
|
||||||
cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0)
|
|
||||||
|
|
||||||
for i in range(track_num):
|
|
||||||
if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height:
|
|
||||||
continue
|
|
||||||
x, y = cur_tracks[i]
|
|
||||||
x, y = int(x // w_down), int(y // h_down)
|
|
||||||
track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x
|
|
||||||
|
|
||||||
return track_pos # the position embeddings, [N, T', 2], 2 = height, width
|
|
||||||
|
|
||||||
def replace_feature(
|
|
||||||
vae_feature: torch.Tensor, # [B, C', T', H', W']
|
|
||||||
track_pos: torch.Tensor, # [B, N, T', 2]
|
|
||||||
strength: float = 1.0
|
|
||||||
) -> torch.Tensor:
|
|
||||||
b, _, t, h, w = vae_feature.shape
|
|
||||||
assert b == track_pos.shape[0], "Batch size mismatch."
|
|
||||||
n = track_pos.shape[1]
|
|
||||||
|
|
||||||
# Shuffle the trajectory order
|
|
||||||
track_pos = track_pos[:, torch.randperm(n), :, :]
|
|
||||||
|
|
||||||
# Extract coordinates at time steps ≥ 1 and generate a valid mask
|
|
||||||
current_pos = track_pos[:, :, 1:, :] # [B, N, T-1, 2]
|
|
||||||
mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) # [B, N, T-1]
|
|
||||||
|
|
||||||
# Get all valid indices
|
|
||||||
valid_indices = mask.nonzero(as_tuple=False) # [num_valid, 3]
|
|
||||||
num_valid = valid_indices.shape[0]
|
|
||||||
|
|
||||||
if num_valid == 0:
|
|
||||||
return vae_feature
|
|
||||||
|
|
||||||
# Decompose valid indices into each dimension
|
|
||||||
batch_idx = valid_indices[:, 0]
|
|
||||||
track_idx = valid_indices[:, 1]
|
|
||||||
t_rel = valid_indices[:, 2]
|
|
||||||
t_target = t_rel + 1 # Convert to original time step indices
|
|
||||||
|
|
||||||
# Extract target position coordinates
|
|
||||||
h_target = current_pos[batch_idx, track_idx, t_rel, 0].long() # Ensure integer indices
|
|
||||||
w_target = current_pos[batch_idx, track_idx, t_rel, 1].long()
|
|
||||||
|
|
||||||
# Extract source position coordinates (t=0)
|
|
||||||
h_source = track_pos[batch_idx, track_idx, 0, 0].long()
|
|
||||||
w_source = track_pos[batch_idx, track_idx, 0, 1].long()
|
|
||||||
|
|
||||||
# Get source features and assign to target positions
|
|
||||||
src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
|
|
||||||
dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]
|
|
||||||
|
|
||||||
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
|
|
||||||
|
|
||||||
|
|
||||||
return vae_feature
|
|
||||||
|
|
||||||
# Visualize functions
|
|
||||||
|
|
||||||
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
|
|
||||||
draw = ImageDraw.Draw(overlay, 'RGBA')
|
|
||||||
points = points[::-1]
|
|
||||||
|
|
||||||
# Compute total length
|
|
||||||
total_length = 0
|
|
||||||
segment_lengths = []
|
|
||||||
for i in range(len(points) - 1):
|
|
||||||
dx = points[i + 1][0] - points[i][0]
|
|
||||||
dy = points[i + 1][1] - points[i][1]
|
|
||||||
length = (dx * dx + dy * dy) ** 0.5
|
|
||||||
segment_lengths.append(length)
|
|
||||||
total_length += length
|
|
||||||
|
|
||||||
if total_length == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
accumulated_length = 0
|
|
||||||
|
|
||||||
# Draw the gradient polyline
|
|
||||||
for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])):
|
|
||||||
segment_length = segment_lengths[idx]
|
|
||||||
steps = max(int(segment_length), 1)
|
|
||||||
|
|
||||||
for i in range(steps):
|
|
||||||
current_length = accumulated_length + (i / steps) * segment_length
|
|
||||||
ratio = current_length / total_length
|
|
||||||
|
|
||||||
alpha = int(255 * (1 - ratio) * opacity)
|
|
||||||
color = (*start_color, alpha)
|
|
||||||
|
|
||||||
x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps)
|
|
||||||
y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps)
|
|
||||||
|
|
||||||
dynamic_line_width = max(int(line_width * (1 - ratio)), 1)
|
|
||||||
draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width)
|
|
||||||
|
|
||||||
accumulated_length += segment_length
|
|
||||||
|
|
||||||
|
|
||||||
def add_weighted(rgb, track):
|
|
||||||
rgb = np.array(rgb) # [H, W, C] "RGB"
|
|
||||||
track = np.array(track) # [H, W, C] "RGBA"
|
|
||||||
|
|
||||||
alpha = track[:, :, 3] / 255.0
|
|
||||||
alpha = np.stack([alpha] * 3, axis=-1)
|
|
||||||
blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha)
|
|
||||||
|
|
||||||
return Image.fromarray(blend_img.astype(np.uint8))
|
|
||||||
|
|
||||||
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
|
|
||||||
color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
|
|
||||||
|
|
||||||
video = video.byte().cpu().numpy() # (81, 480, 832, 3)
|
|
||||||
tracks = tracks[0].long().detach().cpu().numpy()
|
|
||||||
if visibility is not None:
|
|
||||||
visibility = visibility[0].detach().cpu().numpy()
|
|
||||||
|
|
||||||
num_frames, height, width = video.shape[:3]
|
|
||||||
num_tracks = tracks.shape[1]
|
|
||||||
alpha_opacity = int(255 * opacity)
|
|
||||||
|
|
||||||
output_frames = []
|
|
||||||
for t in range(num_frames):
|
|
||||||
frame_rgb = video[t].astype(np.float32)
|
|
||||||
|
|
||||||
# Create a single RGBA overlay for all tracks in this frame
|
|
||||||
overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
|
|
||||||
draw_overlay = ImageDraw.Draw(overlay)
|
|
||||||
|
|
||||||
polyline_data = []
|
|
||||||
|
|
||||||
# Draw all circles on a single overlay
|
|
||||||
for n in range(num_tracks):
|
|
||||||
if visibility is not None and visibility[t, n] == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
track_coord = tracks[t, n]
|
|
||||||
color = color_map[n % len(color_map)]
|
|
||||||
circle_color = color + (alpha_opacity,)
|
|
||||||
|
|
||||||
draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size),
|
|
||||||
fill=circle_color
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store polyline data for batch processing
|
|
||||||
tracks_coord = tracks[max(t - track_frame, 0):t + 1, n]
|
|
||||||
if len(tracks_coord) > 1:
|
|
||||||
polyline_data.append((tracks_coord, color))
|
|
||||||
|
|
||||||
# Blend circles overlay once
|
|
||||||
overlay_np = np.array(overlay)
|
|
||||||
alpha = overlay_np[:, :, 3:4] / 255.0
|
|
||||||
frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
|
|
||||||
|
|
||||||
# Draw all polylines on a single overlay
|
|
||||||
if polyline_data:
|
|
||||||
polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
|
|
||||||
for tracks_coord, color in polyline_data:
|
|
||||||
_draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity)
|
|
||||||
|
|
||||||
# Blend polylines overlay once
|
|
||||||
polyline_np = np.array(polyline_overlay)
|
|
||||||
alpha = polyline_np[:, :, 3:4] / 255.0
|
|
||||||
frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
|
|
||||||
|
|
||||||
output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8)))
|
|
||||||
|
|
||||||
return output_frames
|
|
||||||
|
|
||||||
|
|
||||||
class WanMoveVisualizeTracks(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return io.Schema(
|
|
||||||
node_id="WanMoveVisualizeTracks",
|
|
||||||
category="conditioning/video_models",
|
|
||||||
inputs=[
|
|
||||||
io.Image.Input("images"),
|
|
||||||
io.Tracks.Input("tracks", optional=True),
|
|
||||||
io.Int.Input("line_resolution", default=24, min=1, max=1024),
|
|
||||||
io.Int.Input("circle_size", default=12, min=1, max=128),
|
|
||||||
io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
|
|
||||||
io.Int.Input("line_width", default=16, min=1, max=128),
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
io.Image.Output(),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput:
|
|
||||||
if tracks is None:
|
|
||||||
return io.NodeOutput(images)
|
|
||||||
|
|
||||||
track_path = tracks["track_path"].unsqueeze(0)
|
|
||||||
track_visibility = tracks["track_visibility"].unsqueeze(0)
|
|
||||||
images_in = images * 255.0
|
|
||||||
if images_in.shape[0] != track_path.shape[1]:
|
|
||||||
repeat_count = track_path.shape[1] // images.shape[0]
|
|
||||||
images_in = images_in.repeat(repeat_count, 1, 1, 1)
|
|
||||||
track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
|
|
||||||
track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float()
|
|
||||||
|
|
||||||
return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device()))
|
|
||||||
|
|
||||||
|
|
||||||
class WanMoveTracksFromCoords(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return io.Schema(
|
|
||||||
node_id="WanMoveTracksFromCoords",
|
|
||||||
category="conditioning/video_models",
|
|
||||||
inputs=[
|
|
||||||
io.String.Input("track_coords", force_input=True, default="[]", optional=True),
|
|
||||||
io.Mask.Input("track_mask", optional=True),
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
io.Tracks.Output(),
|
|
||||||
io.Int.Output(display_name="track_length"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
|
|
||||||
device=comfy.model_management.intermediate_device()
|
|
||||||
|
|
||||||
tracks_data = parse_json_tracks(track_coords)
|
|
||||||
track_length = len(tracks_data[0])
|
|
||||||
|
|
||||||
track_list = [
|
|
||||||
[[track[frame]['x'], track[frame]['y']] for track in tracks_data]
|
|
||||||
for frame in range(len(tracks_data[0]))
|
|
||||||
]
|
|
||||||
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
|
|
||||||
|
|
||||||
num_tracks = tracks.shape[-2]
|
|
||||||
if track_mask is None:
|
|
||||||
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
|
|
||||||
else:
|
|
||||||
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
|
|
||||||
|
|
||||||
out_track_info = {}
|
|
||||||
out_track_info["track_path"] = tracks
|
|
||||||
out_track_info["track_visibility"] = track_visibility
|
|
||||||
return io.NodeOutput(out_track_info, track_length)
|
|
||||||
|
|
||||||
|
|
||||||
class GenerateTracks(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return io.Schema(
|
|
||||||
node_id="GenerateTracks",
|
|
||||||
category="conditioning/video_models",
|
|
||||||
inputs=[
|
|
||||||
io.Int.Input("width", default=832, min=16, max=4096, step=16),
|
|
||||||
io.Int.Input("height", default=480, min=16, max=4096, step=16),
|
|
||||||
io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
|
|
||||||
io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
|
|
||||||
io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
|
|
||||||
io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
|
|
||||||
io.Int.Input("num_frames", default=81, min=1, max=1024),
|
|
||||||
io.Int.Input("num_tracks", default=5, min=1, max=100),
|
|
||||||
io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
|
|
||||||
io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
|
|
||||||
io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
|
|
||||||
io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
|
|
||||||
io.Combo.Input(
|
|
||||||
"interpolation",
|
|
||||||
options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
|
|
||||||
tooltip="Controls the timing/speed of movement along the path.",
|
|
||||||
),
|
|
||||||
io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
io.Tracks.Output(),
|
|
||||||
io.Int.Output(display_name="track_length"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
|
|
||||||
track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
|
|
||||||
device = comfy.model_management.intermediate_device()
|
|
||||||
track_length = num_frames
|
|
||||||
|
|
||||||
# normalized coordinates to pixel coordinates
|
|
||||||
start_x_px = start_x * width
|
|
||||||
start_y_px = start_y * height
|
|
||||||
mid_x_px = mid_x * width
|
|
||||||
mid_y_px = mid_y * height
|
|
||||||
end_x_px = end_x * width
|
|
||||||
end_y_px = end_y * height
|
|
||||||
|
|
||||||
track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional
|
|
||||||
|
|
||||||
t = torch.linspace(0, 1, num_frames, device=device)
|
|
||||||
if interpolation == "constant": # All points stay at start position
|
|
||||||
interp_values = torch.zeros_like(t)
|
|
||||||
elif interpolation == "linear":
|
|
||||||
interp_values = t
|
|
||||||
elif interpolation == "ease_in":
|
|
||||||
interp_values = t ** 2
|
|
||||||
elif interpolation == "ease_out":
|
|
||||||
interp_values = 1 - (1 - t) ** 2
|
|
||||||
elif interpolation == "ease_in_out":
|
|
||||||
interp_values = t * t * (3 - 2 * t)
|
|
||||||
|
|
||||||
if bezier: # apply interpolation to t for timing control along the bezier path
|
|
||||||
t_interp = interp_values
|
|
||||||
one_minus_t = 1 - t_interp
|
|
||||||
x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
|
|
||||||
y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
|
|
||||||
tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
|
|
||||||
tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
|
|
||||||
else: # calculate base x and y positions for each frame (center track)
|
|
||||||
x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
|
|
||||||
y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
|
|
||||||
# For non-bezier, tangent is constant (direction from start to end)
|
|
||||||
tangent_x = torch.full_like(t, end_x_px - start_x_px)
|
|
||||||
tangent_y = torch.full_like(t, end_y_px - start_y_px)
|
|
||||||
|
|
||||||
track_list = []
|
|
||||||
for frame_idx in range(num_frames):
|
|
||||||
# Calculate perpendicular direction at this frame
|
|
||||||
tx = tangent_x[frame_idx].item()
|
|
||||||
ty = tangent_y[frame_idx].item()
|
|
||||||
length = (tx ** 2 + ty ** 2) ** 0.5
|
|
||||||
|
|
||||||
if length > 0: # Perpendicular unit vector (rotate 90 degrees)
|
|
||||||
perp_x = -ty / length
|
|
||||||
perp_y = tx / length
|
|
||||||
else: # If tangent is zero, spread horizontally
|
|
||||||
perp_x = 1.0
|
|
||||||
perp_y = 0.0
|
|
||||||
|
|
||||||
frame_tracks = []
|
|
||||||
for track_idx in range(num_tracks): # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
|
|
||||||
offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
|
|
||||||
track_x = x_positions[frame_idx].item() + perp_x * offset
|
|
||||||
track_y = y_positions[frame_idx].item() + perp_y * offset
|
|
||||||
frame_tracks.append([track_x, track_y])
|
|
||||||
track_list.append(frame_tracks)
|
|
||||||
|
|
||||||
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
|
|
||||||
|
|
||||||
if track_mask is None:
|
|
||||||
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
|
|
||||||
else:
|
|
||||||
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
|
|
||||||
|
|
||||||
out_track_info = {}
|
|
||||||
out_track_info["track_path"] = tracks
|
|
||||||
out_track_info["track_visibility"] = track_visibility
|
|
||||||
return io.NodeOutput(out_track_info, track_length)
|
|
||||||
|
|
||||||
|
|
||||||
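# --- Editor's sketch, not part of the original nodes_wanmove.py: it only restates the
# --- closed forms evaluated by execute() above. The path is a quadratic Bezier curve
# ---   B(t) = (1 - t)^2 * P0 + 2 * (1 - t) * t * P1 + t^2 * P2
# --- with derivative B'(t) = 2 * (1 - t) * (P1 - P0) + 2 * t * (P2 - P1); B'(t), rotated
# --- 90 degrees, gives the perpendicular along which the parallel tracks are spread.
# --- Values below are illustrative only.
# import torch
# t = torch.linspace(0, 1, 5)                                     # 5 sampled frames
# p0, p1, p2 = 0.0, 50.0, 100.0                                   # start / control / end, in pixels
# point = (1 - t) ** 2 * p0 + 2 * (1 - t) * t * p1 + t ** 2 * p2  # position along the curve
# tangent = 2 * (1 - t) * (p1 - p0) + 2 * t * (p2 - p1)           # direction of motion per frame
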
class WanMoveConcatTrack(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanMoveConcatTrack",
            category="conditioning/video_models",
            inputs=[
                io.Tracks.Input("tracks_1"),
                io.Tracks.Input("tracks_2", optional=True),
            ],
            outputs=[
                io.Tracks.Output(),
            ],
        )

    @classmethod
    def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput:
        if tracks_2 is None:
            return io.NodeOutput(tracks_1)

        tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1)  # Concatenate along the track dimension
        mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1)

        out_track_info = {}
        out_track_info["track_path"] = tracks_out
        out_track_info["track_visibility"] = mask_out
        return io.NodeOutput(out_track_info)

class WanMoveTrackToVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanMoveTrackToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Tracks.Input("tracks", optional=True),
                io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.Image.Input("start_image"),
                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
        device = comfy.model_management.intermediate_device()
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
        if start_image is not None:
            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
            image[:start_image.shape[0]] = start_image

            concat_latent_image = vae.encode(image[:, :, :, :3])
            mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
            mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0

            if tracks is not None and strength > 0.0:
                tracks_path = tracks["track_path"][:length]  # [T, N, 2]
                num_tracks = tracks_path.shape[-2]

                track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device))

                track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
                track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
                concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)
            else:
                concat_latent_image_pos = concat_latent_image

            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask})
            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})

        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        out_latent = {}
        out_latent["samples"] = latent
        return io.NodeOutput(positive, negative, out_latent)

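# --- Editor's note, illustrative and not part of the original file: the empty latent created
# --- in WanMoveTrackToVideo.execute above uses 4x temporal and 8x spatial downscaling, so the
# --- node defaults (batch_size=1, length=81, height=480, width=832) produce a tensor of shape
# --- [1, 16, 21, 60, 104], since ((81 - 1) // 4) + 1 == 21, 480 // 8 == 60 and 832 // 8 == 104.
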
class WanMoveExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            WanMoveTrackToVideo,
            WanMoveTracksFromCoords,
            WanMoveConcatTrack,
            WanMoveVisualizeTracks,
            GenerateTracks,
        ]


async def comfy_entrypoint() -> WanMoveExtension:
    return WanMoveExtension()

@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.5.1"
+__version__ = "0.4.0"

@@ -13,7 +13,6 @@ import asyncio
 import torch
 
 import comfy.model_management
-from latent_preview import set_preview_method
 import nodes
 from comfy_execution.caching import (
     BasicCache,
@@ -670,8 +669,6 @@ class PromptExecutor:
         asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
 
     async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
-        set_preview_method(extra_data.get("preview_method"))
-
         nodes.interrupt_processing(False)
 
         if "client_id" in extra_data:

@@ -8,8 +8,6 @@ import folder_paths
 import comfy.utils
 import logging
 
-default_preview_method = args.preview_method
-
 MAX_PREVIEW_RESOLUTION = args.preview_size
 VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
 
@@ -127,11 +125,3 @@ def prepare_callback(model, steps, x0_output_dict=None):
         pbar.update_absolute(step + 1, total_steps, preview_bytes)
     return callback
 
-def set_preview_method(override: str = None):
-    if override and override != "default":
-        method = LatentPreviewMethod.from_string(override)
-        if method is not None:
-            args.preview_method = method
-            return
-    args.preview_method = default_preview_method
-

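Taken together, the two hunks above remove master's per-queue preview override: execution.py stops calling set_preview_method() at the start of each job, and latent_preview.py drops the helper and the captured default. As a rough illustration of how the master side wires this up (the sample extra_data value is an assumption for the example, not taken from v0.4.0):

    # Sketch of the master-side override flow removed in v0.4.0
    from latent_preview import set_preview_method

    extra_data = {"preview_method": "taesd"}                 # sent alongside a queued prompt
    set_preview_method(extra_data.get("preview_method"))     # overrides args.preview_method for this run
    set_preview_method(None)                                 # a run without an override restores the server default
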
66  main.py
@@ -23,38 +23,6 @@ if __name__ == "__main__":
 
     setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
 
-if os.name == "nt":
-    os.environ['MIMALLOC_PURGE_DELAY'] = '0'
-
-if __name__ == "__main__":
-    os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
-    if args.default_device is not None:
-        default_dev = args.default_device
-        devices = list(range(32))
-        devices.remove(default_dev)
-        devices.insert(0, default_dev)
-        devices = ','.join(map(str, devices))
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
-        os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
-
-    if args.cuda_device is not None:
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
-        os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
-        os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
-        logging.info("Set cuda device to: {}".format(args.cuda_device))
-
-    if args.oneapi_device_selector is not None:
-        os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
-        logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
-
-    if args.deterministic:
-        if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
-            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
-
-    import cuda_malloc
-    if "rocm" in cuda_malloc.get_torch_version_noimport():
-        os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
-
 
 def handle_comfyui_manager_unavailable():
     if not args.windows_standalone_build:
@@ -169,6 +137,40 @@ import shutil
 import threading
 import gc
 
+
+if os.name == "nt":
+    os.environ['MIMALLOC_PURGE_DELAY'] = '0'
+
+if __name__ == "__main__":
+    os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
+    if args.default_device is not None:
+        default_dev = args.default_device
+        devices = list(range(32))
+        devices.remove(default_dev)
+        devices.insert(0, default_dev)
+        devices = ','.join(map(str, devices))
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
+        os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
+
+    if args.cuda_device is not None:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
+        os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
+        os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
+        logging.info("Set cuda device to: {}".format(args.cuda_device))
+
+    if args.oneapi_device_selector is not None:
+        os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
+        logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
+
+    if args.deterministic:
+        if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
+            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
+
+    import cuda_malloc
+    if "rocm" in cuda_malloc.get_torch_version_noimport():
+        os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
+
+
 if 'torch' in sys.modules:
     logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
 

@@ -1 +1 @@
-comfyui_manager==4.0.3b5
+comfyui_manager==4.0.3b4

2  nodes.py
@@ -2358,7 +2358,6 @@ async def init_builtin_extra_nodes():
         "nodes_logic.py",
         "nodes_nop.py",
         "nodes_kandinsky5.py",
-        "nodes_wanmove.py",
     ]
 
     import_failed = []
@@ -2384,6 +2383,7 @@ async def init_builtin_api_nodes():
         "nodes_recraft.py",
         "nodes_pixverse.py",
         "nodes_stability.py",
+        "nodes_pika.py",
         "nodes_runway.py",
         "nodes_sora.py",
         "nodes_topaz.py",

@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.5.1"
+version = "0.4.0"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"

@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.34.9
-comfyui-workflow-templates==0.7.59
+comfyui-frontend-package==1.33.13
+comfyui-workflow-templates==0.7.54
 comfyui-embedded-docs==0.3.1
 torch
 torchsde

135  server.py
@@ -7,7 +7,6 @@ import time
 import nodes
 import folder_paths
 import execution
-from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
 import uuid
 import urllib
 import json
@@ -48,12 +47,6 @@ from middleware.cache_middleware import cache_control
 if args.enable_manager:
     import comfyui_manager
 
-
-def _remove_sensitive_from_queue(queue: list) -> list:
-    """Remove sensitive data (index 5) from queue item tuples."""
-    return [item[:5] for item in queue]
-
-
 async def send_socket_catch_exception(function, message):
     try:
         await function(message)
@@ -701,129 +694,6 @@ class PromptServer():
                 out[node_class] = node_info(node_class)
             return web.json_response(out)
 
-        @routes.get("/api/jobs")
-        async def get_jobs(request):
-            """List all jobs with filtering, sorting, and pagination.
-
-            Query parameters:
-                status: Filter by status (comma-separated): pending, in_progress, completed, failed
-                workflow_id: Filter by workflow ID
-                sort_by: Sort field: created_at (default), execution_duration
-                sort_order: Sort direction: asc, desc (default)
-                limit: Max items to return (positive integer)
-                offset: Items to skip (non-negative integer, default 0)
-            """
-            query = request.rel_url.query
-
-            status_param = query.get('status')
-            workflow_id = query.get('workflow_id')
-            sort_by = query.get('sort_by', 'created_at').lower()
-            sort_order = query.get('sort_order', 'desc').lower()
-
-            status_filter = None
-            if status_param:
-                status_filter = [s.strip().lower() for s in status_param.split(',') if s.strip()]
-                invalid_statuses = [s for s in status_filter if s not in JobStatus.ALL]
-                if invalid_statuses:
-                    return web.json_response(
-                        {"error": f"Invalid status value(s): {', '.join(invalid_statuses)}. Valid values: {', '.join(JobStatus.ALL)}"},
-                        status=400
-                    )
-
-            if sort_by not in {'created_at', 'execution_duration'}:
-                return web.json_response(
-                    {"error": "sort_by must be 'created_at' or 'execution_duration'"},
-                    status=400
-                )
-
-            if sort_order not in {'asc', 'desc'}:
-                return web.json_response(
-                    {"error": "sort_order must be 'asc' or 'desc'"},
-                    status=400
-                )
-
-            limit = None
-
-            # If limit is provided, validate that it is a positive integer, else continue without a limit
-            if 'limit' in query:
-                try:
-                    limit = int(query.get('limit'))
-                    if limit <= 0:
-                        return web.json_response(
-                            {"error": "limit must be a positive integer"},
-                            status=400
-                        )
-                except (ValueError, TypeError):
-                    return web.json_response(
-                        {"error": "limit must be an integer"},
-                        status=400
-                    )
-
-            offset = 0
-            if 'offset' in query:
-                try:
-                    offset = int(query.get('offset'))
-                    if offset < 0:
-                        offset = 0
-                except (ValueError, TypeError):
-                    return web.json_response(
-                        {"error": "offset must be an integer"},
-                        status=400
-                    )
-
-            running, queued = self.prompt_queue.get_current_queue_volatile()
-            history = self.prompt_queue.get_history()
-
-            running = _remove_sensitive_from_queue(running)
-            queued = _remove_sensitive_from_queue(queued)
-
-            jobs, total = get_all_jobs(
-                running, queued, history,
-                status_filter=status_filter,
-                workflow_id=workflow_id,
-                sort_by=sort_by,
-                sort_order=sort_order,
-                limit=limit,
-                offset=offset
-            )
-
-            has_more = (offset + len(jobs)) < total
-
-            return web.json_response({
-                'jobs': jobs,
-                'pagination': {
-                    'offset': offset,
-                    'limit': limit,
-                    'total': total,
-                    'has_more': has_more
-                }
-            })
-
-        @routes.get("/api/jobs/{job_id}")
-        async def get_job_by_id(request):
-            """Get a single job by ID."""
-            job_id = request.match_info.get("job_id", None)
-            if not job_id:
-                return web.json_response(
-                    {"error": "job_id is required"},
-                    status=400
-                )
-
-            running, queued = self.prompt_queue.get_current_queue_volatile()
-            history = self.prompt_queue.get_history(prompt_id=job_id)
-
-            running = _remove_sensitive_from_queue(running)
-            queued = _remove_sensitive_from_queue(queued)
-
-            job = get_job(job_id, running, queued, history)
-            if job is None:
-                return web.json_response(
-                    {"error": "Job not found"},
-                    status=404
-                )
-
-            return web.json_response(job)
-
         @routes.get("/history")
         async def get_history(request):
             max_items = request.rel_url.query.get("max_items", None)
@@ -847,8 +717,9 @@ class PromptServer():
         async def get_queue(request):
             queue_info = {}
             current_queue = self.prompt_queue.get_current_queue_volatile()
-            queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0])
-            queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
+            remove_sensitive = lambda queue: [x[:5] for x in queue]
+            queue_info['queue_running'] = remove_sensitive(current_queue[0])
+            queue_info['queue_pending'] = remove_sensitive(current_queue[1])
             return web.json_response(queue_info)
 
         @routes.post("/prompt")

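For reference, the removed /api/jobs endpoint above accepts status, workflow_id, sort_by, sort_order, limit, and offset as query parameters and returns a jobs list plus a pagination object. A minimal client call against a local master server might look like the following sketch (the host and port are assumptions, and the endpoint does not exist in v0.4.0):

    import json
    import urllib.parse
    import urllib.request

    params = urllib.parse.urlencode({"status": "completed", "limit": 2, "sort_order": "desc"})
    with urllib.request.urlopen(f"http://127.0.0.1:8188/api/jobs?{params}") as response:
        page = json.loads(response.read())
    print(page["pagination"], [job["id"] for job in page["jobs"]])
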
@ -1,352 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for Queue-specific Preview Method Override feature.
|
|
||||||
|
|
||||||
Tests the preview method override functionality:
|
|
||||||
- LatentPreviewMethod.from_string() method
|
|
||||||
- set_preview_method() function in latent_preview.py
|
|
||||||
- default_preview_method variable
|
|
||||||
- Integration with args.preview_method
|
|
||||||
"""
|
|
||||||
import pytest
|
|
||||||
from comfy.cli_args import args, LatentPreviewMethod
|
|
||||||
from latent_preview import set_preview_method, default_preview_method
|
|
||||||
|
|
||||||
|
|
||||||
class TestLatentPreviewMethodFromString:
|
|
||||||
"""Test LatentPreviewMethod.from_string() classmethod."""
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("value,expected", [
|
|
||||||
("auto", LatentPreviewMethod.Auto),
|
|
||||||
("latent2rgb", LatentPreviewMethod.Latent2RGB),
|
|
||||||
("taesd", LatentPreviewMethod.TAESD),
|
|
||||||
("none", LatentPreviewMethod.NoPreviews),
|
|
||||||
])
|
|
||||||
def test_valid_values_return_enum(self, value, expected):
|
|
||||||
"""Valid string values should return corresponding enum."""
|
|
||||||
assert LatentPreviewMethod.from_string(value) == expected
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("invalid", [
|
|
||||||
"invalid",
|
|
||||||
"TAESD", # Case sensitive
|
|
||||||
"AUTO", # Case sensitive
|
|
||||||
"Latent2RGB", # Case sensitive
|
|
||||||
"latent",
|
|
||||||
"",
|
|
||||||
"default", # default is special, not a method
|
|
||||||
])
|
|
||||||
def test_invalid_values_return_none(self, invalid):
|
|
||||||
"""Invalid string values should return None."""
|
|
||||||
assert LatentPreviewMethod.from_string(invalid) is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestLatentPreviewMethodEnumValues:
|
|
||||||
"""Test LatentPreviewMethod enum has expected values."""
|
|
||||||
|
|
||||||
def test_enum_values(self):
|
|
||||||
"""Verify enum values match expected strings."""
|
|
||||||
assert LatentPreviewMethod.NoPreviews.value == "none"
|
|
||||||
assert LatentPreviewMethod.Auto.value == "auto"
|
|
||||||
assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb"
|
|
||||||
assert LatentPreviewMethod.TAESD.value == "taesd"
|
|
||||||
|
|
||||||
def test_enum_count(self):
|
|
||||||
"""Verify exactly 4 preview methods exist."""
|
|
||||||
assert len(LatentPreviewMethod) == 4
|
|
||||||
|
|
||||||
|
|
||||||
class TestSetPreviewMethod:
|
|
||||||
"""Test set_preview_method() function from latent_preview.py."""
|
|
||||||
|
|
||||||
def setup_method(self):
|
|
||||||
"""Store original value before each test."""
|
|
||||||
self.original = args.preview_method
|
|
||||||
|
|
||||||
def teardown_method(self):
|
|
||||||
"""Restore original value after each test."""
|
|
||||||
args.preview_method = self.original
|
|
||||||
|
|
||||||
def test_override_with_taesd(self):
|
|
||||||
"""'taesd' should set args.preview_method to TAESD."""
|
|
||||||
set_preview_method("taesd")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
|
||||||
|
|
||||||
def test_override_with_latent2rgb(self):
|
|
||||||
"""'latent2rgb' should set args.preview_method to Latent2RGB."""
|
|
||||||
set_preview_method("latent2rgb")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
|
||||||
|
|
||||||
def test_override_with_auto(self):
|
|
||||||
"""'auto' should set args.preview_method to Auto."""
|
|
||||||
set_preview_method("auto")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.Auto
|
|
||||||
|
|
||||||
def test_override_with_none_value(self):
|
|
||||||
"""'none' should set args.preview_method to NoPreviews."""
|
|
||||||
set_preview_method("none")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
|
||||||
|
|
||||||
def test_default_restores_original(self):
|
|
||||||
"""'default' should restore to default_preview_method."""
|
|
||||||
# First override to something else
|
|
||||||
set_preview_method("taesd")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
|
||||||
|
|
||||||
# Then use 'default' to restore
|
|
||||||
set_preview_method("default")
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
def test_none_param_restores_original(self):
|
|
||||||
"""None parameter should restore to default_preview_method."""
|
|
||||||
# First override to something else
|
|
||||||
set_preview_method("taesd")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
|
||||||
|
|
||||||
# Then use None to restore
|
|
||||||
set_preview_method(None)
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
def test_empty_string_restores_original(self):
|
|
||||||
"""Empty string should restore to default_preview_method."""
|
|
||||||
set_preview_method("taesd")
|
|
||||||
set_preview_method("")
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
def test_invalid_value_restores_original(self):
|
|
||||||
"""Invalid value should restore to default_preview_method."""
|
|
||||||
set_preview_method("taesd")
|
|
||||||
set_preview_method("invalid_method")
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
def test_case_sensitive_invalid_restores(self):
|
|
||||||
"""Case-mismatched values should restore to default."""
|
|
||||||
set_preview_method("taesd")
|
|
||||||
set_preview_method("TAESD") # Wrong case
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
|
|
||||||
class TestDefaultPreviewMethod:
|
|
||||||
"""Test default_preview_method module variable."""
|
|
||||||
|
|
||||||
def test_default_is_not_none(self):
|
|
||||||
"""default_preview_method should not be None."""
|
|
||||||
assert default_preview_method is not None
|
|
||||||
|
|
||||||
def test_default_is_enum_member(self):
|
|
||||||
"""default_preview_method should be a LatentPreviewMethod enum."""
|
|
||||||
assert isinstance(default_preview_method, LatentPreviewMethod)
|
|
||||||
|
|
||||||
def test_default_matches_args_initial(self):
|
|
||||||
"""default_preview_method should match CLI default or user setting."""
|
|
||||||
# This tests that default_preview_method was captured at module load
|
|
||||||
# After set_preview_method(None), args should equal default
|
|
||||||
original = args.preview_method
|
|
||||||
set_preview_method("taesd")
|
|
||||||
set_preview_method(None)
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
args.preview_method = original
|
|
||||||
|
|
||||||
|
|
||||||
class TestArgsPreviewMethodModification:
|
|
||||||
"""Test args.preview_method can be modified correctly."""
|
|
||||||
|
|
||||||
def setup_method(self):
|
|
||||||
"""Store original value before each test."""
|
|
||||||
self.original = args.preview_method
|
|
||||||
|
|
||||||
def teardown_method(self):
|
|
||||||
"""Restore original value after each test."""
|
|
||||||
args.preview_method = self.original
|
|
||||||
|
|
||||||
def test_args_accepts_all_enum_values(self):
|
|
||||||
"""args.preview_method should accept all LatentPreviewMethod values."""
|
|
||||||
for method in LatentPreviewMethod:
|
|
||||||
args.preview_method = method
|
|
||||||
assert args.preview_method == method
|
|
||||||
|
|
||||||
def test_args_modification_and_restoration(self):
|
|
||||||
"""args.preview_method should be modifiable and restorable."""
|
|
||||||
original = args.preview_method
|
|
||||||
|
|
||||||
args.preview_method = LatentPreviewMethod.TAESD
|
|
||||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
|
||||||
|
|
||||||
args.preview_method = original
|
|
||||||
assert args.preview_method == original
|
|
||||||
|
|
||||||
|
|
||||||
class TestExecutionFlow:
|
|
||||||
"""Test the execution flow pattern used in execution.py."""
|
|
||||||
|
|
||||||
def setup_method(self):
|
|
||||||
"""Store original value before each test."""
|
|
||||||
self.original = args.preview_method
|
|
||||||
|
|
||||||
def teardown_method(self):
|
|
||||||
"""Restore original value after each test."""
|
|
||||||
args.preview_method = self.original
|
|
||||||
|
|
||||||
def test_sequential_executions_with_different_methods(self):
|
|
||||||
"""Simulate multiple queue executions with different preview methods."""
|
|
||||||
# Execution 1: taesd
|
|
||||||
set_preview_method("taesd")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
|
||||||
|
|
||||||
# Execution 2: none
|
|
||||||
set_preview_method("none")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
|
||||||
|
|
||||||
# Execution 3: default (restore)
|
|
||||||
set_preview_method("default")
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
# Execution 4: auto
|
|
||||||
set_preview_method("auto")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.Auto
|
|
||||||
|
|
||||||
# Execution 5: no override (None)
|
|
||||||
set_preview_method(None)
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
def test_override_then_default_pattern(self):
|
|
||||||
"""Test the pattern: override -> execute -> next call restores."""
|
|
||||||
# First execution with override
|
|
||||||
set_preview_method("latent2rgb")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
|
||||||
|
|
||||||
# Second execution without override restores default
|
|
||||||
set_preview_method(None)
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
def test_extra_data_simulation(self):
|
|
||||||
"""Simulate extra_data.get('preview_method') patterns."""
|
|
||||||
# Simulate: extra_data = {"preview_method": "taesd"}
|
|
||||||
extra_data = {"preview_method": "taesd"}
|
|
||||||
set_preview_method(extra_data.get("preview_method"))
|
|
||||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
|
||||||
|
|
||||||
# Simulate: extra_data = {}
|
|
||||||
extra_data = {}
|
|
||||||
set_preview_method(extra_data.get("preview_method"))
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
# Simulate: extra_data = {"preview_method": "default"}
|
|
||||||
extra_data = {"preview_method": "default"}
|
|
||||||
set_preview_method(extra_data.get("preview_method"))
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
|
|
||||||
class TestRealWorldScenarios:
|
|
||||||
"""Tests using real-world prompt data patterns."""
|
|
||||||
|
|
||||||
def setup_method(self):
|
|
||||||
"""Store original value before each test."""
|
|
||||||
self.original = args.preview_method
|
|
||||||
|
|
||||||
def teardown_method(self):
|
|
||||||
"""Restore original value after each test."""
|
|
||||||
args.preview_method = self.original
|
|
||||||
|
|
||||||
def test_captured_prompt_without_preview_method(self):
|
|
||||||
"""
|
|
||||||
Test with captured prompt that has no preview_method.
|
|
||||||
Based on: tests-unit/execution_test/fixtures/default_prompt.json
|
|
||||||
"""
|
|
||||||
# Real captured extra_data structure (preview_method absent)
|
|
||||||
extra_data = {
|
|
||||||
"extra_pnginfo": {"workflow": {}},
|
|
||||||
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
|
|
||||||
"create_time": 1765416558179
|
|
||||||
}
|
|
||||||
|
|
||||||
set_preview_method(extra_data.get("preview_method"))
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
def test_captured_prompt_with_preview_method_taesd(self):
|
|
||||||
"""Test captured prompt with preview_method: taesd."""
|
|
||||||
extra_data = {
|
|
||||||
"extra_pnginfo": {"workflow": {}},
|
|
||||||
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
|
|
||||||
"preview_method": "taesd"
|
|
||||||
}
|
|
||||||
|
|
||||||
set_preview_method(extra_data.get("preview_method"))
|
|
||||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
|
||||||
|
|
||||||
def test_captured_prompt_with_preview_method_none(self):
|
|
||||||
"""Test captured prompt with preview_method: none (disable preview)."""
|
|
||||||
extra_data = {
|
|
||||||
"extra_pnginfo": {"workflow": {}},
|
|
||||||
"client_id": "test-client",
|
|
||||||
"preview_method": "none"
|
|
||||||
}
|
|
||||||
|
|
||||||
set_preview_method(extra_data.get("preview_method"))
|
|
||||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
|
||||||
|
|
||||||
def test_captured_prompt_with_preview_method_latent2rgb(self):
|
|
||||||
"""Test captured prompt with preview_method: latent2rgb."""
|
|
||||||
extra_data = {
|
|
||||||
"extra_pnginfo": {"workflow": {}},
|
|
||||||
"client_id": "test-client",
|
|
||||||
"preview_method": "latent2rgb"
|
|
||||||
}
|
|
||||||
|
|
||||||
set_preview_method(extra_data.get("preview_method"))
|
|
||||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
|
||||||
|
|
||||||
def test_captured_prompt_with_preview_method_auto(self):
|
|
||||||
"""Test captured prompt with preview_method: auto."""
|
|
||||||
extra_data = {
|
|
||||||
"extra_pnginfo": {"workflow": {}},
|
|
||||||
"client_id": "test-client",
|
|
||||||
"preview_method": "auto"
|
|
||||||
}
|
|
||||||
|
|
||||||
set_preview_method(extra_data.get("preview_method"))
|
|
||||||
assert args.preview_method == LatentPreviewMethod.Auto
|
|
||||||
|
|
||||||
def test_captured_prompt_with_preview_method_default(self):
|
|
||||||
"""Test captured prompt with preview_method: default (use CLI setting)."""
|
|
||||||
# First set to something else
|
|
||||||
set_preview_method("taesd")
|
|
||||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
|
||||||
|
|
||||||
# Then simulate a prompt with "default"
|
|
||||||
extra_data = {
|
|
||||||
"extra_pnginfo": {"workflow": {}},
|
|
||||||
"client_id": "test-client",
|
|
||||||
"preview_method": "default"
|
|
||||||
}
|
|
||||||
|
|
||||||
set_preview_method(extra_data.get("preview_method"))
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
def test_sequential_queue_with_different_preview_methods(self):
|
|
||||||
"""
|
|
||||||
Simulate real queue scenario: multiple prompts with different settings.
|
|
||||||
This tests the actual usage pattern in ComfyUI.
|
|
||||||
"""
|
|
||||||
# Queue 1: User wants TAESD preview
|
|
||||||
extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"}
|
|
||||||
set_preview_method(extra_data_1.get("preview_method"))
|
|
||||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
|
||||||
|
|
||||||
# Queue 2: User wants no preview (faster execution)
|
|
||||||
extra_data_2 = {"client_id": "client-2", "preview_method": "none"}
|
|
||||||
set_preview_method(extra_data_2.get("preview_method"))
|
|
||||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
|
||||||
|
|
||||||
# Queue 3: User doesn't specify (use server default)
|
|
||||||
extra_data_3 = {"client_id": "client-3"}
|
|
||||||
set_preview_method(extra_data_3.get("preview_method"))
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
# Queue 4: User explicitly wants default
|
|
||||||
extra_data_4 = {"client_id": "client-4", "preview_method": "default"}
|
|
||||||
set_preview_method(extra_data_4.get("preview_method"))
|
|
||||||
assert args.preview_method == default_preview_method
|
|
||||||
|
|
||||||
# Queue 5: User wants latent2rgb
|
|
||||||
extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"}
|
|
||||||
set_preview_method(extra_data_5.get("preview_method"))
|
|
||||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
|
||||||
@ -99,37 +99,6 @@ class ComfyClient:
|
|||||||
with urllib.request.urlopen(url) as response:
|
with urllib.request.urlopen(url) as response:
|
||||||
return json.loads(response.read())
|
return json.loads(response.read())
|
||||||
|
|
||||||
def get_jobs(self, status=None, limit=None, offset=None, sort_by=None, sort_order=None):
|
|
||||||
url = "http://{}/api/jobs".format(self.server_address)
|
|
||||||
params = {}
|
|
||||||
if status is not None:
|
|
||||||
params["status"] = status
|
|
||||||
if limit is not None:
|
|
||||||
params["limit"] = limit
|
|
||||||
if offset is not None:
|
|
||||||
params["offset"] = offset
|
|
||||||
if sort_by is not None:
|
|
||||||
params["sort_by"] = sort_by
|
|
||||||
if sort_order is not None:
|
|
||||||
params["sort_order"] = sort_order
|
|
||||||
|
|
||||||
if params:
|
|
||||||
url_values = urllib.parse.urlencode(params)
|
|
||||||
url = "{}?{}".format(url, url_values)
|
|
||||||
|
|
||||||
with urllib.request.urlopen(url) as response:
|
|
||||||
return json.loads(response.read())
|
|
||||||
|
|
||||||
def get_job(self, job_id):
|
|
||||||
url = "http://{}/api/jobs/{}".format(self.server_address, job_id)
|
|
||||||
try:
|
|
||||||
with urllib.request.urlopen(url) as response:
|
|
||||||
return json.loads(response.read())
|
|
||||||
except urllib.error.HTTPError as e:
|
|
||||||
if e.code == 404:
|
|
||||||
return None
|
|
||||||
raise
|
|
||||||
|
|
||||||
def set_test_name(self, name):
|
def set_test_name(self, name):
|
||||||
self.test_name = name
|
self.test_name = name
|
||||||
|
|
||||||
@ -908,106 +877,3 @@ class TestExecution:
|
|||||||
result = client.get_all_history(max_items=5, offset=len(all_history) - 1)
|
result = client.get_all_history(max_items=5, offset=len(all_history) - 1)
|
||||||
|
|
||||||
assert len(result) <= 1, "Should return at most 1 item when offset is near end"
|
assert len(result) <= 1, "Should return at most 1 item when offset is near end"
|
||||||
|
|
||||||
# Jobs API tests
|
|
||||||
def test_jobs_api_job_structure(
|
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
|
||||||
):
|
|
||||||
"""Test that job objects have required fields"""
|
|
||||||
self._create_history_item(client, builder)
|
|
||||||
|
|
||||||
jobs_response = client.get_jobs(status="completed", limit=1)
|
|
||||||
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
|
|
||||||
|
|
||||||
job = jobs_response["jobs"][0]
|
|
||||||
assert "id" in job, "Job should have id"
|
|
||||||
assert "status" in job, "Job should have status"
|
|
||||||
assert "create_time" in job, "Job should have create_time"
|
|
||||||
assert "outputs_count" in job, "Job should have outputs_count"
|
|
||||||
assert "preview_output" in job, "Job should have preview_output"
|
|
||||||
|
|
||||||
def test_jobs_api_preview_output_structure(
|
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
|
||||||
):
|
|
||||||
"""Test that preview_output has correct structure"""
|
|
||||||
self._create_history_item(client, builder)
|
|
||||||
|
|
||||||
jobs_response = client.get_jobs(status="completed", limit=1)
|
|
||||||
job = jobs_response["jobs"][0]
|
|
||||||
|
|
||||||
if job["preview_output"] is not None:
|
|
||||||
preview = job["preview_output"]
|
|
||||||
assert "filename" in preview, "Preview should have filename"
|
|
||||||
assert "nodeId" in preview, "Preview should have nodeId"
|
|
||||||
assert "mediaType" in preview, "Preview should have mediaType"
|
|
||||||
|
|
||||||
def test_jobs_api_pagination(
|
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
|
||||||
):
|
|
||||||
"""Test jobs API pagination"""
|
|
||||||
for _ in range(5):
|
|
||||||
self._create_history_item(client, builder)
|
|
||||||
|
|
||||||
first_page = client.get_jobs(limit=2, offset=0)
|
|
||||||
second_page = client.get_jobs(limit=2, offset=2)
|
|
||||||
|
|
||||||
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
|
|
||||||
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
|
|
||||||
|
|
||||||
first_ids = {j["id"] for j in first_page["jobs"]}
|
|
||||||
second_ids = {j["id"] for j in second_page["jobs"]}
|
|
||||||
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
|
|
||||||
|
|
||||||
def test_jobs_api_sorting(
|
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
|
||||||
):
|
|
||||||
"""Test jobs API sorting"""
|
|
||||||
for _ in range(3):
|
|
||||||
self._create_history_item(client, builder)
|
|
||||||
|
|
||||||
desc_jobs = client.get_jobs(sort_order="desc")
|
|
||||||
asc_jobs = client.get_jobs(sort_order="asc")
|
|
||||||
|
|
||||||
if len(desc_jobs["jobs"]) >= 2:
|
|
||||||
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
|
|
||||||
asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]]
|
|
||||||
if len(desc_times) >= 2:
|
|
||||||
assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first"
|
|
||||||
if len(asc_times) >= 2:
|
|
||||||
assert asc_times == sorted(asc_times), "Asc should be oldest first"
|
|
||||||
|
|
||||||
def test_jobs_api_status_filter(
|
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
|
||||||
):
|
|
||||||
"""Test jobs API status filtering"""
|
|
||||||
self._create_history_item(client, builder)
|
|
||||||
|
|
||||||
completed_jobs = client.get_jobs(status="completed")
|
|
||||||
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
|
|
||||||
|
|
||||||
for job in completed_jobs["jobs"]:
|
|
||||||
assert job["status"] == "completed", "Should only return completed jobs"
|
|
||||||
|
|
||||||
# Pending jobs are transient - just verify filter doesn't error
|
|
||||||
pending_jobs = client.get_jobs(status="pending")
|
|
||||||
for job in pending_jobs["jobs"]:
|
|
||||||
assert job["status"] == "pending", "Should only return pending jobs"
|
|
||||||
|
|
||||||
def test_get_job_by_id(
|
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
|
||||||
):
|
|
||||||
"""Test getting a single job by ID"""
|
|
||||||
result = self._create_history_item(client, builder)
|
|
||||||
prompt_id = result.get_prompt_id()
|
|
||||||
|
|
||||||
job = client.get_job(prompt_id)
|
|
||||||
assert job is not None, "Should find the job"
|
|
||||||
assert job["id"] == prompt_id, "Job ID should match"
|
|
||||||
assert "outputs" in job, "Single job should include outputs"
|
|
||||||
|
|
||||||
def test_get_job_not_found(
|
|
||||||
self, client: ComfyClient, builder: GraphBuilder
|
|
||||||
):
|
|
||||||
"""Test getting a non-existent job returns 404"""
|
|
||||||
job = client.get_job("nonexistent-job-id")
|
|
||||||
assert job is None, "Non-existent job should return None"
|
|
||||||
|
|||||||
@ -1,361 +0,0 @@
|
|||||||
"""Unit tests for comfy_execution/jobs.py"""
|
|
||||||
|
|
||||||
from comfy_execution.jobs import (
|
|
||||||
JobStatus,
|
|
||||||
is_previewable,
|
|
||||||
normalize_queue_item,
|
|
||||||
normalize_history_item,
|
|
||||||
get_outputs_summary,
|
|
||||||
apply_sorting,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestJobStatus:
|
|
||||||
"""Test JobStatus constants."""
|
|
||||||
|
|
||||||
def test_status_values(self):
|
|
||||||
"""Status constants should have expected string values."""
|
|
||||||
assert JobStatus.PENDING == 'pending'
|
|
||||||
assert JobStatus.IN_PROGRESS == 'in_progress'
|
|
||||||
assert JobStatus.COMPLETED == 'completed'
|
|
||||||
assert JobStatus.FAILED == 'failed'
|
|
||||||
|
|
||||||
def test_all_contains_all_statuses(self):
|
|
||||||
"""ALL should contain all status values."""
|
|
||||||
assert JobStatus.PENDING in JobStatus.ALL
|
|
||||||
assert JobStatus.IN_PROGRESS in JobStatus.ALL
|
|
||||||
assert JobStatus.COMPLETED in JobStatus.ALL
|
|
||||||
assert JobStatus.FAILED in JobStatus.ALL
|
|
||||||
assert len(JobStatus.ALL) == 4
|
|
||||||
|
|
||||||
|
|
||||||
class TestIsPreviewable:
|
|
||||||
"""Unit tests for is_previewable()"""
|
|
||||||
|
|
||||||
def test_previewable_media_types(self):
|
|
||||||
"""Images, video, audio media types should be previewable."""
|
|
||||||
for media_type in ['images', 'video', 'audio']:
|
|
||||||
assert is_previewable(media_type, {}) is True
|
|
||||||
|
|
||||||
def test_non_previewable_media_types(self):
|
|
||||||
"""Other media types should not be previewable."""
|
|
||||||
for media_type in ['latents', 'text', 'metadata', 'files']:
|
|
||||||
assert is_previewable(media_type, {}) is False
|
|
||||||
|
|
||||||
def test_3d_extensions_previewable(self):
|
|
||||||
"""3D file extensions should be previewable regardless of media_type."""
|
|
||||||
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
|
|
||||||
item = {'filename': f'model{ext}'}
|
|
||||||
assert is_previewable('files', item) is True
|
|
||||||
|
|
||||||
def test_3d_extensions_case_insensitive(self):
|
|
||||||
"""3D extension check should be case insensitive."""
|
|
||||||
item = {'filename': 'MODEL.GLB'}
|
|
||||||
assert is_previewable('files', item) is True
|
|
||||||
|
|
||||||
def test_video_format_previewable(self):
|
|
||||||
"""Items with video/ format should be previewable."""
|
|
||||||
item = {'format': 'video/mp4'}
|
|
||||||
assert is_previewable('files', item) is True
|
|
||||||
|
|
||||||
def test_audio_format_previewable(self):
|
|
||||||
"""Items with audio/ format should be previewable."""
|
|
||||||
item = {'format': 'audio/wav'}
|
|
||||||
assert is_previewable('files', item) is True
|
|
||||||
|
|
||||||
def test_other_format_not_previewable(self):
|
|
||||||
"""Items with other format should not be previewable."""
|
|
||||||
item = {'format': 'application/json'}
|
|
||||||
assert is_previewable('files', item) is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetOutputsSummary:
|
|
||||||
"""Unit tests for get_outputs_summary()"""
|
|
||||||
|
|
||||||
def test_empty_outputs(self):
|
|
||||||
"""Empty outputs should return 0 count and None preview."""
|
|
||||||
count, preview = get_outputs_summary({})
|
|
||||||
assert count == 0
|
|
||||||
assert preview is None
|
|
||||||
|
|
||||||
def test_counts_across_multiple_nodes(self):
|
|
||||||
"""Outputs from multiple nodes should all be counted."""
|
|
||||||
outputs = {
|
|
||||||
'node1': {'images': [{'filename': 'a.png', 'type': 'output'}]},
|
|
||||||
'node2': {'images': [{'filename': 'b.png', 'type': 'output'}]},
|
|
||||||
'node3': {'images': [
|
|
||||||
{'filename': 'c.png', 'type': 'output'},
|
|
||||||
{'filename': 'd.png', 'type': 'output'}
|
|
||||||
]}
|
|
||||||
}
|
|
||||||
count, preview = get_outputs_summary(outputs)
|
|
||||||
assert count == 4
|
|
||||||
|
|
||||||
def test_skips_animated_key_and_non_list_values(self):
|
|
||||||
"""The 'animated' key and non-list values should be skipped."""
|
|
||||||
outputs = {
|
|
||||||
'node1': {
|
|
||||||
'images': [{'filename': 'test.png', 'type': 'output'}],
|
|
||||||
'animated': [True], # Should skip due to key name
|
|
||||||
'metadata': 'string', # Should skip due to non-list
|
|
||||||
'count': 42 # Should skip due to non-list
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count, preview = get_outputs_summary(outputs)
|
|
||||||
assert count == 1
|
|
||||||
|
|
||||||
def test_preview_prefers_type_output(self):
|
|
||||||
"""Items with type='output' should be preferred for preview."""
|
|
||||||
outputs = {
|
|
||||||
'node1': {
|
|
||||||
'images': [
|
|
||||||
{'filename': 'temp.png', 'type': 'temp'},
|
|
||||||
{'filename': 'output.png', 'type': 'output'}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count, preview = get_outputs_summary(outputs)
|
|
||||||
assert count == 2
|
|
||||||
assert preview['filename'] == 'output.png'
|
|
||||||
|
|
||||||
def test_preview_fallback_when_no_output_type(self):
|
|
||||||
"""If no type='output', should use first previewable."""
|
|
||||||
outputs = {
|
|
||||||
'node1': {
|
|
||||||
'images': [
|
|
||||||
{'filename': 'temp1.png', 'type': 'temp'},
|
|
||||||
{'filename': 'temp2.png', 'type': 'temp'}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count, preview = get_outputs_summary(outputs)
|
|
||||||
assert preview['filename'] == 'temp1.png'
|
|
||||||
|
|
||||||
def test_non_previewable_media_types_counted_but_no_preview(self):
|
|
||||||
"""Non-previewable media types should be counted but not used as preview."""
|
|
||||||
outputs = {
|
|
||||||
'node1': {
|
|
||||||
'latents': [
|
|
||||||
{'filename': 'latent1.safetensors'},
|
|
||||||
{'filename': 'latent2.safetensors'}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count, preview = get_outputs_summary(outputs)
|
|
||||||
assert count == 2
|
|
||||||
assert preview is None
|
|
||||||
|
|
||||||
def test_previewable_media_types(self):
|
|
||||||
"""Images, video, and audio media types should be previewable."""
|
|
||||||
for media_type in ['images', 'video', 'audio']:
|
|
||||||
outputs = {
|
|
||||||
'node1': {
|
|
||||||
media_type: [{'filename': 'test.file', 'type': 'output'}]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count, preview = get_outputs_summary(outputs)
|
|
||||||
assert preview is not None, f"{media_type} should be previewable"
|
|
||||||
|
|
||||||
def test_3d_files_previewable(self):
|
|
||||||
"""3D file extensions should be previewable."""
|
|
||||||
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
|
|
||||||
outputs = {
|
|
||||||
'node1': {
|
|
||||||
'files': [{'filename': f'model{ext}', 'type': 'output'}]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count, preview = get_outputs_summary(outputs)
|
|
||||||
assert preview is not None, f"3D file {ext} should be previewable"
|
|
||||||
|
|
||||||
def test_format_mime_type_previewable(self):
|
|
||||||
"""Files with video/ or audio/ format should be previewable."""
|
|
||||||
for fmt in ['video/x-custom', 'audio/x-custom']:
|
|
||||||
outputs = {
|
|
||||||
'node1': {
|
|
||||||
'files': [{'filename': 'file.custom', 'format': fmt, 'type': 'output'}]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count, preview = get_outputs_summary(outputs)
|
|
||||||
assert preview is not None, f"Format {fmt} should be previewable"
|
|
||||||
|
|
||||||
def test_preview_enriched_with_node_metadata(self):
|
|
||||||
"""Preview should include nodeId, mediaType, and original fields."""
|
|
||||||
outputs = {
|
|
||||||
'node123': {
|
|
||||||
'images': [{'filename': 'test.png', 'type': 'output', 'subfolder': 'outputs'}]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count, preview = get_outputs_summary(outputs)
|
|
||||||
assert preview['nodeId'] == 'node123'
|
|
||||||
assert preview['mediaType'] == 'images'
|
|
||||||
assert preview['subfolder'] == 'outputs'
|
|
||||||
|
|
||||||
|
|
||||||
class TestApplySorting:
|
|
||||||
"""Unit tests for apply_sorting()"""
|
|
||||||
|
|
||||||
def test_sort_by_create_time_desc(self):
|
|
||||||
"""Default sort by create_time descending."""
|
|
||||||
jobs = [
|
|
||||||
{'id': 'a', 'create_time': 100},
|
|
||||||
{'id': 'b', 'create_time': 300},
|
|
||||||
{'id': 'c', 'create_time': 200},
|
|
||||||
]
|
|
||||||
result = apply_sorting(jobs, 'created_at', 'desc')
|
|
||||||
assert [j['id'] for j in result] == ['b', 'c', 'a']
|
|
||||||
|
|
||||||
def test_sort_by_create_time_asc(self):
|
|
||||||
"""Sort by create_time ascending."""
|
|
||||||
jobs = [
|
|
||||||
{'id': 'a', 'create_time': 100},
|
|
||||||
{'id': 'b', 'create_time': 300},
|
|
||||||
{'id': 'c', 'create_time': 200},
|
|
||||||
]
|
|
||||||
result = apply_sorting(jobs, 'created_at', 'asc')
|
|
||||||
assert [j['id'] for j in result] == ['a', 'c', 'b']
|
|
||||||
|
|
||||||
def test_sort_by_execution_duration(self):
|
|
||||||
"""Sort by execution_duration should order by duration."""
|
|
||||||
jobs = [
|
|
||||||
{'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100}, # 5s
|
|
||||||
{'id': 'b', 'create_time': 300, 'execution_start_time': 300, 'execution_end_time': 1300}, # 1s
|
|
||||||
{'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200}, # 3s
|
|
||||||
]
|
|
||||||
result = apply_sorting(jobs, 'execution_duration', 'desc')
|
|
||||||
assert [j['id'] for j in result] == ['a', 'c', 'b']
|
|
||||||
|
|
||||||
def test_sort_with_none_values(self):
|
|
||||||
"""Jobs with None values should sort as 0."""
|
|
||||||
jobs = [
|
|
||||||
{'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100},
|
|
||||||
{'id': 'b', 'create_time': 300, 'execution_start_time': None, 'execution_end_time': None},
|
|
||||||
{'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200},
|
|
||||||
]
|
|
||||||
result = apply_sorting(jobs, 'execution_duration', 'asc')
|
|
||||||
assert result[0]['id'] == 'b' # None treated as 0, comes first
|
|
||||||
|
|
||||||
|
|
||||||
class TestNormalizeQueueItem:
|
|
||||||
"""Unit tests for normalize_queue_item()"""
|
|
||||||
|
|
||||||
def test_basic_normalization(self):
|
|
||||||
"""Queue item should be normalized to job dict."""
|
|
||||||
item = (
|
|
||||||
10, # priority/number
|
|
||||||
'prompt-123', # prompt_id
|
|
||||||
{'nodes': {}}, # prompt
|
|
||||||
{
|
|
||||||
'create_time': 1234567890,
|
|
||||||
'extra_pnginfo': {'workflow': {'id': 'workflow-abc'}}
|
|
||||||
}, # extra_data
|
|
||||||
['node1'], # outputs_to_execute
|
|
||||||
)
|
|
||||||
job = normalize_queue_item(item, JobStatus.PENDING)
|
|
||||||
|
|
||||||
assert job['id'] == 'prompt-123'
|
|
||||||
assert job['status'] == 'pending'
|
|
||||||
assert job['priority'] == 10
|
|
||||||
assert job['create_time'] == 1234567890
|
|
||||||
assert 'execution_start_time' not in job
|
|
||||||
assert 'execution_end_time' not in job
|
|
||||||
assert 'execution_error' not in job
|
|
||||||
assert 'preview_output' not in job
|
|
||||||
assert job['outputs_count'] == 0
|
|
||||||
assert job['workflow_id'] == 'workflow-abc'
|
|
||||||
|
|
||||||
|
|
||||||
class TestNormalizeHistoryItem:
    """Unit tests for normalize_history_item()"""

    def test_completed_job(self):
        """Completed history item should have correct status and times from messages."""
        history_item = {
            'prompt': (
                5,  # priority
                'prompt-456',
                {'nodes': {}},
                {
                    'create_time': 1234567890000,
                    'extra_pnginfo': {'workflow': {'id': 'workflow-xyz'}}
                },
                ['node1'],
            ),
            'status': {
                'status_str': 'success',
                'completed': True,
                'messages': [
                    ('execution_start', {'prompt_id': 'prompt-456', 'timestamp': 1234567890500}),
                    ('execution_success', {'prompt_id': 'prompt-456', 'timestamp': 1234567893000}),
                ]
            },
            'outputs': {},
        }
        job = normalize_history_item('prompt-456', history_item)

        assert job['id'] == 'prompt-456'
        assert job['status'] == 'completed'
        assert job['priority'] == 5
        assert job['execution_start_time'] == 1234567890500
        assert job['execution_end_time'] == 1234567893000
        assert job['workflow_id'] == 'workflow-xyz'

    def test_failed_job(self):
        """Failed history item should have failed status and error from messages."""
        history_item = {
            'prompt': (
                5,
                'prompt-789',
                {'nodes': {}},
                {'create_time': 1234567890000},
                ['node1'],
            ),
            'status': {
                'status_str': 'error',
                'completed': False,
                'messages': [
                    ('execution_start', {'prompt_id': 'prompt-789', 'timestamp': 1234567890500}),
                    ('execution_error', {
                        'prompt_id': 'prompt-789',
                        'node_id': '5',
                        'node_type': 'KSampler',
                        'exception_message': 'CUDA out of memory',
                        'exception_type': 'RuntimeError',
                        'traceback': ['Traceback...', 'RuntimeError: CUDA out of memory'],
                        'timestamp': 1234567891000,
                    })
                ]
            },
            'outputs': {},
        }

        job = normalize_history_item('prompt-789', history_item)
        assert job['status'] == 'failed'
        assert job['execution_start_time'] == 1234567890500
        assert job['execution_end_time'] == 1234567891000
        assert job['execution_error']['node_id'] == '5'
        assert job['execution_error']['node_type'] == 'KSampler'
        assert job['execution_error']['exception_message'] == 'CUDA out of memory'

    def test_include_outputs(self):
        """When include_outputs=True, should include full output data."""
        history_item = {
            'prompt': (
                5,
                'prompt-123',
                {'nodes': {'1': {}}},
                {'create_time': 1234567890, 'client_id': 'abc'},
                ['node1'],
            ),
            'status': {'status_str': 'success', 'completed': True, 'messages': []},
            'outputs': {'node1': {'images': [{'filename': 'test.png'}]}},
        }
        job = normalize_history_item('prompt-123', history_item, include_outputs=True)

        assert 'outputs' in job
        assert 'workflow' in job
        assert 'execution_status' in job
        assert job['outputs'] == {'node1': {'images': [{'filename': 'test.png'}]}}
        assert job['workflow'] == {
            'prompt': {'nodes': {'1': {}}},
            'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
        }

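# Illustrative sketch (not the real normalize_history_item): how the execution
# times and error asserted above can be recovered from a history item's status
# messages, which are (event_name, data) tuples.
def _execution_info_from_messages(messages):
    start = end = error = None
    for name, data in messages:
        if name == 'execution_start':
            start = data.get('timestamp')
        elif name in ('execution_success', 'execution_error'):
            end = data.get('timestamp')
            if name == 'execution_error':
                error = data  # carries node_id, node_type, exception_message, ...
    return start, end, error
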
@ -1,358 +0,0 @@
"""
|
|
||||||
E2E tests for Queue-specific Preview Method Override feature.
|
|
||||||
|
|
||||||
Tests actual execution with different preview_method values.
|
|
||||||
Requires a running ComfyUI server with models.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
|
|
||||||
|
|
||||||
Note:
|
|
||||||
These tests execute actual image generation and wait for completion.
|
|
||||||
Tests verify preview image transmission based on preview_method setting.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import pytest
|
|
||||||
import uuid
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
import websocket
|
|
||||||
import urllib.request
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
# Server configuration
|
|
||||||
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
|
|
||||||
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
|
|
||||||
|
|
||||||
# Use existing inference graph fixture
|
|
||||||
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
|
|
||||||
|
|
||||||
|
|
||||||
def is_server_running() -> bool:
|
|
||||||
"""Check if ComfyUI server is running."""
|
|
||||||
try:
|
|
||||||
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
|
|
||||||
with urllib.request.urlopen(request, timeout=2.0):
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
|
|
||||||
"""Prepare graph for testing: randomize seeds and reduce steps."""
|
|
||||||
adapted = json.loads(json.dumps(graph)) # Deep copy
|
|
||||||
for node_id, node in adapted.items():
|
|
||||||
inputs = node.get("inputs", {})
|
|
||||||
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
|
|
||||||
if "seed" in inputs:
|
|
||||||
inputs["seed"] = random.randint(0, 2**32 - 1)
|
|
||||||
if "noise_seed" in inputs:
|
|
||||||
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
|
|
||||||
# Reduce steps for faster testing (default 20 -> 5)
|
|
||||||
if "steps" in inputs:
|
|
||||||
inputs["steps"] = steps
|
|
||||||
return adapted
|
|
||||||
|
|
||||||
|
|
||||||
# Alias for backward compatibility
|
|
||||||
randomize_seed = prepare_graph_for_test
|
|
||||||
|
|
||||||
|
|
||||||
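# Usage sketch for prepare_graph_for_test (illustrative only; the node id and
# class_type below are made up, not part of the real SDXL fixture):
#
#     _graph = {"3": {"class_type": "KSampler", "inputs": {"seed": 0, "steps": 20}}}
#     _fast = prepare_graph_for_test(_graph, steps=5)
#     # _fast["3"]["inputs"]["steps"] == 5, its seed is re-randomized, and the
#     # original _graph is untouched because the function deep-copies via JSON.
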
class PreviewMethodClient:
    """Client for testing preview_method with WebSocket execution tracking."""

    def __init__(self, server_address: str):
        self.server_address = server_address
        self.client_id = str(uuid.uuid4())
        self.ws = None

    def connect(self):
        """Connect to WebSocket."""
        self.ws = websocket.WebSocket()
        self.ws.settimeout(120)  # 2 minute timeout for sampling
        self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")

    def close(self):
        """Close WebSocket connection."""
        if self.ws:
            self.ws.close()

    def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
        """Queue a prompt and return response with prompt_id."""
        data = {
            "prompt": prompt,
            "client_id": self.client_id,
            "extra_data": extra_data or {}
        }
        req = urllib.request.Request(
            f"http://{self.server_address}/prompt",
            data=json.dumps(data).encode("utf-8"),
            headers={"Content-Type": "application/json"}
        )
        return json.loads(urllib.request.urlopen(req).read())

    def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
        """
        Wait for execution to complete via WebSocket.

        Returns:
            dict with keys: completed, error, preview_count, execution_time
        """
        result = {
            "completed": False,
            "error": None,
            "preview_count": 0,
            "execution_time": 0.0
        }

        start_time = time.time()
        self.ws.settimeout(timeout)

        try:
            while True:
                out = self.ws.recv()
                elapsed = time.time() - start_time

                if isinstance(out, str):
                    message = json.loads(out)
                    msg_type = message.get("type")
                    data = message.get("data", {})

                    if data.get("prompt_id") != prompt_id:
                        continue

                    if msg_type == "executing":
                        if data.get("node") is None:
                            # Execution complete
                            result["completed"] = True
                            result["execution_time"] = elapsed
                            break

                    elif msg_type == "execution_error":
                        result["error"] = data
                        result["execution_time"] = elapsed
                        break

                    elif msg_type == "progress":
                        # Progress update during sampling
                        pass

                elif isinstance(out, bytes):
                    # Binary data = preview image
                    result["preview_count"] += 1

        except websocket.WebSocketTimeoutException:
            result["error"] = "Timeout waiting for execution"
            result["execution_time"] = time.time() - start_time

        return result

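# Standalone usage sketch for PreviewMethodClient (illustrative only; assumes a
# reachable server at SERVER_HOST and a valid API-format prompt dict):
#
#     client = PreviewMethodClient(SERVER_HOST)
#     client.connect()
#     try:
#         resp = client.queue_prompt(prompt, {"preview_method": "latent2rgb"})
#         outcome = client.wait_for_execution(resp["prompt_id"])
#         print(outcome["completed"], outcome["preview_count"])
#     finally:
#         client.close()
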
def load_graph() -> dict:
    """Load the SDXL graph fixture with randomized seed."""
    with open(GRAPH_FILE) as f:
        graph = json.load(f)
    return randomize_seed(graph)  # Avoid caching


# Skip all tests if server is not running
pytestmark = [
    pytest.mark.skipif(
        not is_server_running(),
        reason=f"ComfyUI server not running at {SERVER_URL}"
    ),
    pytest.mark.preview_method,
    pytest.mark.execution,
]


@pytest.fixture
def client():
    """Create and connect a test client."""
    c = PreviewMethodClient(SERVER_HOST)
    c.connect()
    yield c
    c.close()


@pytest.fixture
def graph():
    """Load the test graph."""
    return load_graph()


class TestPreviewMethodExecution:
    """Test actual execution with different preview methods."""

    def test_execution_with_latent2rgb(self, client, graph):
        """
        Execute with preview_method=latent2rgb.
        Should complete and potentially receive preview images.
        """
        extra_data = {"preview_method": "latent2rgb"}

        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])

        # Should complete (may error if model missing, but that's separate)
        assert result["completed"] or result["error"] is not None
        # Execution should take some time (sampling)
        if result["completed"]:
            assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
            # latent2rgb should produce previews
            print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201

    def test_execution_with_taesd(self, client, graph):
        """
        Execute with preview_method=taesd.
        TAESD provides higher quality previews.
        """
        extra_data = {"preview_method": "taesd"}

        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])

        assert result["completed"] or result["error"] is not None
        if result["completed"]:
            assert result["execution_time"] > 0.5
            # taesd should also produce previews
            print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201

    def test_execution_with_none_preview(self, client, graph):
        """
        Execute with preview_method=none.
        No preview images should be generated.
        """
        extra_data = {"preview_method": "none"}

        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])

        assert result["completed"] or result["error"] is not None
        if result["completed"]:
            # With "none", should receive no preview images
            assert result["preview_count"] == 0, \
                f"Expected no previews with 'none', got {result['preview_count']}"
            print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201

    def test_execution_with_default(self, client, graph):
        """
        Execute with preview_method=default.
        Should use server's CLI default setting.
        """
        extra_data = {"preview_method": "default"}

        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])

        assert result["completed"] or result["error"] is not None
        if result["completed"]:
            print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201

    def test_execution_without_preview_method(self, client, graph):
        """
        Execute without preview_method in extra_data.
        Should use server's default preview method.
        """
        extra_data = {}  # No preview_method

        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])

        assert result["completed"] or result["error"] is not None
        if result["completed"]:
            print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201


class TestPreviewMethodComparison:
    """Compare preview behavior between different methods."""

    def test_none_vs_latent2rgb_preview_count(self, client, graph):
        """
        Compare preview counts: 'none' should have 0, others should have >0.
        This is the key verification that preview_method actually works.
        """
        results = {}

        # Run with none (randomize seed to avoid caching)
        graph_none = randomize_seed(graph)
        extra_data_none = {"preview_method": "none"}
        response = client.queue_prompt(graph_none, extra_data_none)
        results["none"] = client.wait_for_execution(response["prompt_id"])

        # Run with latent2rgb (randomize seed again)
        graph_rgb = randomize_seed(graph)
        extra_data_rgb = {"preview_method": "latent2rgb"}
        response = client.queue_prompt(graph_rgb, extra_data_rgb)
        results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])

        # Verify both completed
        assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
        assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"

        # Key assertion: 'none' should have 0 previews
        assert results["none"]["preview_count"] == 0, \
            f"'none' should have 0 previews, got {results['none']['preview_count']}"

        # 'latent2rgb' should have at least 1 preview (depends on steps)
        assert results["latent2rgb"]["preview_count"] > 0, \
            f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"

        print("\nPreview count comparison:")  # noqa: T201
        print(f" none: {results['none']['preview_count']} previews")  # noqa: T201
        print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews")  # noqa: T201


class TestPreviewMethodSequential:
    """Test sequential execution with different preview methods."""

    def test_sequential_different_methods(self, client, graph):
        """
        Execute multiple prompts sequentially with different preview methods.
        Each should complete independently with correct preview behavior.
        """
        methods = ["latent2rgb", "none", "default"]
        results = []

        for method in methods:
            # Randomize seed for each execution to avoid caching
            graph_run = randomize_seed(graph)
            extra_data = {"preview_method": method}
            response = client.queue_prompt(graph_run, extra_data)

            result = client.wait_for_execution(response["prompt_id"])
            results.append({
                "method": method,
                "completed": result["completed"],
                "preview_count": result["preview_count"],
                "execution_time": result["execution_time"],
                "error": result["error"]
            })

        # All should complete or have clear errors
        for r in results:
            assert r["completed"] or r["error"] is not None, \
                f"Method {r['method']} neither completed nor errored"

        # "none" should have zero previews if completed
        none_result = next(r for r in results if r["method"] == "none")
        if none_result["completed"]:
            assert none_result["preview_count"] == 0, \
                f"'none' should have 0 previews, got {none_result['preview_count']}"

        print("\nSequential execution results:")  # noqa: T201
        for r in results:
            status = "✓" if r["completed"] else f"✗ ({r['error']})"
            print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s")  # noqa: T201