Merge branch 'master' into offload_to_mmap

Xiaoyu Xu 2025-12-18 16:19:16 +08:00 committed by GitHub
commit 8b433f294a
81 changed files with 4353 additions and 1614 deletions


@ -53,6 +53,16 @@ try:
repo.stash(ident)
except KeyError:
print("nothing to stash") # noqa: T201
except:
print("Could not stash, cleaning index and trying again.") # noqa: T201
repo.state_cleanup()
repo.index.read_tree(repo.head.peel().tree)
repo.index.write()
try:
repo.stash(ident)
except KeyError:
print("nothing to stash.") # noqa: T201
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try:


@ -5,6 +5,7 @@ on:
push:
branches:
- master
- release/**
paths-ignore:
- 'app/**'
- 'input/**'


@ -2,9 +2,9 @@ name: Execution Tests
on:
push:
branches: [ main, master ]
branches: [ main, master, release/** ]
pull_request:
branches: [ main, master ]
branches: [ main, master, release/** ]
jobs:
test:


@ -2,9 +2,9 @@ name: Test server launches without errors
on:
push:
branches: [ main, master ]
branches: [ main, master, release/** ]
pull_request:
branches: [ main, master ]
branches: [ main, master, release/** ]
jobs:
test:


@ -2,9 +2,9 @@ name: Unit Tests
on:
push:
branches: [ main, master ]
branches: [ main, master, release/** ]
pull_request:
branches: [ main, master ]
branches: [ main, master, release/** ]
jobs:
test:


@ -6,6 +6,7 @@ on:
- "pyproject.toml"
branches:
- master
- release/**
jobs:
update-version:


@ -58,8 +58,13 @@ class InternalRoutes:
return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type)
def is_visible_file(entry: os.DirEntry) -> bool:
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
return entry.is_file() and not entry.name.startswith('.')
sorted_files = sorted(
(entry for entry in os.scandir(directory) if entry.is_file()),
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)
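For reference, a minimal standalone sketch of the same hidden-file filtering and newest-first sort added above, assuming a plain directory path rather than the server's get_directory_by_type lookup:

import os

def list_visible_files_newest_first(directory: str) -> list[str]:
    """Hypothetical helper: non-hidden regular files in `directory`, newest first."""
    def is_visible_file(entry: os.DirEntry) -> bool:
        # Skip dotfiles such as .DS_Store on macOS.
        return entry.is_file() and not entry.name.startswith('.')

    entries = [e for e in os.scandir(directory) if is_visible_file(e)]
    entries.sort(key=lambda e: -e.stat().st_mtime)
    return [e.name for e in entries]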


@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
Latent2RGB = "latent2rgb"
TAESD = "taesd"
@classmethod
def from_string(cls, value: str):
for member in cls:
if member.value == value:
return member
return None
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")


@ -87,6 +87,7 @@ class IndexListCallbacks:
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
EXECUTE_START = "execute_start"
EXECUTE_CLEANUP = "execute_cleanup"
RESIZE_COND_ITEM = "resize_cond_item"
def init_callbacks(self):
return {}
@ -166,6 +167,18 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item = cond_item.copy()
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
# Allow callbacks to handle custom conditioning items
handled = False
for callback in comfy.patcher_extension.get_all_callbacks(
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
):
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
if result is not None:
new_cond_item[cond_key] = result
handled = True
break
if handled:
continue
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
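A hedged sketch of what a RESIZE_COND_ITEM callback could look like, matching the call signature used above (cond_key, cond_value, window, x_in, device, new_cond_item); the key name and the handling are made up:

import torch

def resize_my_custom_cond(cond_key, cond_value, window, x_in, device, new_cond_item):
    """Hypothetical RESIZE_COND_ITEM callback.

    Return a replacement value to mark this conditioning entry as handled,
    or None to fall through to the default tensor / CONDCrossAttn handling.
    """
    if cond_key != "my_custom_cond":
        return None
    # Made-up handling: keep the entry whole and just move it to the device.
    return cond_value.to(device) if isinstance(cond_value, torch.Tensor) else cond_value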


@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
if solver_type not in {"phi_1", "phi_2"}:
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
if solver_type == "phi_1":
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
elif solver_type == "phi_2":
b2 = ei_h_phi_2(-h_eta) / r
b1 = ei_h_phi_1(-h_eta) - b2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
@ -1609,6 +1618,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
x = x + sde_noise * sigmas[i + 1] * s_noise
return x
@torch.no_grad()
def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
"""Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
@torch.no_grad()
def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
"""Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
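For context, a small sketch of the phi_2 weights in isolation, assuming the usual exponential-integrator definitions phi_1(z) = (exp(z) - 1)/z and phi_2(z) = (exp(z) - 1 - z)/z^2; the actual ei_h_* helpers are defined elsewhere in the file and are not shown in this diff:

import torch

def ei_h_phi_1(x):
    return torch.expm1(x) / x              # assumed definition of phi_1

def ei_h_phi_2(x):
    return (torch.expm1(x) - x) / (x * x)  # assumed definition of phi_2

def phi_2_weights(h_eta, r):
    # Mirrors the solver_type == "phi_2" branch above.
    b2 = ei_h_phi_2(-h_eta) / r
    b1 = ei_h_phi_1(-h_eta) - b2
    return b1, b2

b1, b2 = phi_2_weights(torch.tensor(0.5), r=1.0)
print(float(b1), float(b2))  # weights applied to denoised and denoised_2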
@ -1756,7 +1776,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
# Predictor
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
x_pred = denoised
else:
tau_t = tau_func(sigmas[i + 1])
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
@ -1777,7 +1797,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
if tau_t > 0 and s_noise > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
x_pred = x_pred + noise
return x
return x_pred
@torch.no_grad()


@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
nerf_final_head_type: str
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]
use_x0: bool
class ChromaRadiance(Chroma):
"""
@ -159,6 +159,9 @@ class ChromaRadiance(Chroma):
self.skip_dit = []
self.lite = False
if params.use_x0:
self.register_buffer("__x0__", torch.tensor([]))
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
@ -276,6 +279,12 @@ class ChromaRadiance(Chroma):
params_dict |= overrides
return params.__class__(**params_dict)
def _apply_x0_residual(self, predicted, noisy, timesteps):
# non-zero during training to prevent division by zero
eps = 0.0
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
def _forward(
self,
x: Tensor,
@ -316,4 +325,11 @@ class ChromaRadiance(Chroma):
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
# If x0 variant → v-pred, just return this instead
if hasattr(self, "__x0__"):
out = self._apply_x0_residual(out, img, timestep)
return out
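A toy numeric sketch of the x0-to-residual conversion applied above, using made-up tensors; eps stays 0.0 at inference, matching the hunk:

import torch

def apply_x0_residual(predicted_x0, noisy, timesteps, eps=0.0):
    # Mirrors _apply_x0_residual: turn an x0 prediction back into the
    # model-output parameterization expected downstream.
    return (noisy - predicted_x0) / (timesteps.view(-1, 1, 1, 1) + eps)

noisy = torch.randn(2, 3, 8, 8)
predicted_x0 = torch.randn(2, 3, 8, 8)
timesteps = torch.tensor([0.7, 0.3])
print(apply_x0_residual(predicted_x0, noisy, timesteps).shape)  # torch.Size([2, 3, 8, 8])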


@ -43,6 +43,7 @@ class HunyuanVideoParams:
meanflow: bool
use_cond_type_embedding: bool
vision_in_dim: int
meanflow_sum: bool
class SelfAttentionRef(nn.Module):
@ -317,7 +318,7 @@ class HunyuanVideo(nn.Module):
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
vec = (vec + vec_r) / 2
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)


@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
ffn_dim_multiplier: float = (8.0 / 3.0),
norm_eps: float = 1e-5,
qk_norm: bool = True,
n_control_layers=6,
control_in_dim=16,
additional_in_dim=0,
broken=False,
refiner_control=False,
dtype=None,
device=None,
operations=None,
@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
super().__init__()
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.additional_in_dim = 0
self.control_in_dim = 16
self.broken = broken
self.additional_in_dim = additional_in_dim
self.control_in_dim = control_in_dim
n_refiner_layers = 2
self.n_control_layers = 6
self.n_control_layers = n_control_layers
self.control_layers = nn.ModuleList(
[
ZImageControlTransformerBlock(
@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module):
all_x_embedder = {}
patch_size = 2
f_patch_size = 1
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
self.refiner_control = refiner_control
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
self.control_noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=True,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
if self.refiner_control:
self.control_noise_refiner = nn.ModuleList(
[
ZImageControlTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
block_id=layer_id,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
else:
self.control_noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=True,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
patch_size = 2
@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
x_attn_mask = None
for layer in self.control_noise_refiner:
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
if not self.refiner_control:
for layer in self.control_noise_refiner:
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
return control_context
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
if self.refiner_control:
if self.broken:
if layer_id == 0:
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if layer_id > 0:
out = None
for i in range(1, len(self.control_layers)):
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if out is None:
out = o
return (out, control_context)
else:
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
else:
return (None, control_context)
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)


@ -536,6 +536,7 @@ class NextDiT(nn.Module):
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
orig_x = x
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@ -572,13 +573,21 @@ class NextDiT(nn.Module):
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
patches = transformer_options.get("patches", {})
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
padded_img_mask = None
for layer in self.noise_refiner:
x_input = x
for i, layer in enumerate(self.noise_refiner):
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
if "noise_refiner" in patches:
for p in patches["noise_refiner"]:
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
if "img" in out:
x = out["img"]
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
@ -622,14 +631,18 @@ class NextDiT(nn.Module):
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:


@ -218,9 +218,24 @@ class QwenImageTransformerBlock(nn.Module):
operations=operations,
)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _apply_gate(self, x, y, gate, timestep_zero_index=None):
if timestep_zero_index is not None:
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
else:
return torch.addcmul(y, gate, x)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
if timestep_zero_index is not None:
actual_batch = shift.size(0) // 2
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
else:
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
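A standalone sketch of the split gating introduced in _apply_gate above, with made-up shapes; when a split index is given, the regular tokens and the timestep-zero reference tokens each get their own gate:

import torch

def apply_gate(x, y, gate, timestep_zero_index=None):
    # With a split index, `gate` is a pair of per-segment gates; otherwise a
    # single gate tensor is applied as a fused multiply-add.
    if timestep_zero_index is not None:
        return y + torch.cat((x[:, :timestep_zero_index] * gate[0],
                              x[:, timestep_zero_index:] * gate[1]), dim=1)
    return torch.addcmul(y, gate, x)

x = torch.randn(1, 6, 8)   # 4 regular tokens + 2 reference tokens (made up)
y = torch.zeros(1, 6, 8)
gates = (torch.full((1, 1, 1), 0.5), torch.full((1, 1, 1), 2.0))
print(apply_gate(x, y, gates, timestep_zero_index=4).shape)  # torch.Size([1, 6, 8])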
def forward(
self,
@ -229,14 +244,19 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep_zero_index=None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
@ -251,15 +271,15 @@ class QwenImageTransformerBlock(nn.Module):
del img_modulated
del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output
hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@ -302,6 +322,7 @@ class QwenImageTransformer2DModel(nn.Module):
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",
image_model=None,
final_layer=True,
dtype=None,
@ -314,6 +335,7 @@ class QwenImageTransformer2DModel(nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.default_ref_method = default_ref_method
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
@ -341,6 +363,9 @@ class QwenImageTransformer2DModel(nn.Module):
for _ in range(num_layers)
])
if self.default_ref_method == "index_timestep_zero":
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
if final_layer:
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
@ -391,11 +416,14 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
timestep_zero_index = None
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
timestep_zero = ref_method == "index_timestep_zero"
for ref in ref_latents:
if index_ref_method:
index += 1
@ -415,6 +443,10 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@ -446,7 +478,7 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
@ -458,6 +490,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
timestep_zero_index=timestep_zero_index,
transformer_options=transformer_options,
)
@ -474,6 +507,9 @@ class QwenImageTransformer2DModel(nn.Module):
if add is not None:
hidden_states[:, :add.shape[1]] += add
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)


@ -568,7 +568,10 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -763,7 +766,10 @@ class VaceWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -862,7 +868,10 @@ class CameraWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
if audio_emb is not None:
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
# head
@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}


@ -523,7 +523,10 @@ class AnimateWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}


@ -320,6 +320,7 @@ def model_lora_keys_unet(model, key_map={}):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
if isinstance(model, comfy.model_base.Kandinsky5):
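An illustrative sketch of the key aliasing added above: the same internal weight target is now reachable from a diffusers-style transformer. prefix as well; the example key is made up:

def build_aliases(key_lora: str, to: str) -> dict[str, str]:
    # Mirrors the mapping above: three serialized prefixes, one internal target.
    return {
        "diffusion_model.{}".format(key_lora): to,
        "transformer.{}".format(key_lora): to,
        "lycoris_{}".format(key_lora.replace(".", "_")): to,
    }

print(build_aliases("blocks.0.attn.qkv", "diffusion_model.blocks.0.attn.qkv.weight"))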


@ -180,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
dit_config["meanflow_sum"] = True
else:
dit_config["vision_in_dim"] = None
dit_config["meanflow_sum"] = False
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
@ -257,6 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
@ -612,6 +618,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "qwen_image"
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5


@ -39,6 +39,7 @@ import comfy.model_management
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
from comfy.model_management import get_free_memory, get_offload_reserve_ram_gb, get_free_disk
from comfy.quant_ops import QuantizedTensor
@ -210,14 +211,17 @@ class LowVramPatch:
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key)
if weight is None:
return 0
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
if model_dtype is None:
model_dtype = weight.dtype
return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
set_func = None
@ -528,6 +532,9 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_noise_refiner_patch(self, patch):
self.set_model_patch(patch, "noise_refiner")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
@ -741,12 +748,18 @@ class ModelPatcher:
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if weight_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
if bias_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
def check_module_offload_mem(key):
if key in self.patches:
return low_vram_patch_estimate_vram(self.model, key)
model_dtype = getattr(self.model, "manual_cast_dtype", None)
weight, _, _ = get_key_weight(self.model, key)
if model_dtype is None or weight is None:
return 0
if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
return weight.numel() * model_dtype.itemsize
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
loading.append((module_offload_mem, module_mem, n, m, params))
return loading
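A back-of-the-envelope sketch of the revised lowvram patch estimate, assuming a 4096x4096 linear weight and a manual_cast_dtype of fp16; the numbers are illustrative only:

import torch

LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2  # value set in the hunk above

def estimate_patch_vram(numel: int, model_dtype: torch.dtype) -> int:
    # Scratch-space estimate for recomputing a patched weight on the fly.
    return numel * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR

print(estimate_patch_vram(4096 * 4096, torch.float16))  # 67108864 bytes (~64 MiB)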
@ -1012,7 +1025,7 @@ class ModelPatcher:
patch_counter += 1
cast_weight = True
if cast_weight:
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False


@ -22,7 +22,6 @@ import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib
import json
def run_every_op():
@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
else:
offload_stream = None
if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()
non_blocking = comfy.model_management.device_supports_non_blocking(device)
weight_has_function = len(s.weight_function) > 0
@ -505,15 +497,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
if dtype is None:
dtype = MixedPrecisionOps._compute_dtype
self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self._has_bias = bias
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@ -538,7 +529,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
dtype = self.factory_kwargs["dtype"]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
else:
self.register_parameter("bias", None)
else:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm:
@ -568,6 +566,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
requires_grad=False
)
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
else:
self.register_parameter("bias", None)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
@ -589,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
return sd
def _forward(self, input, weight, bias):
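A round-trip sketch of the comfy_quant serialization that the hunk above switches from torch.frombuffer to torch.tensor(list(...)), presumably to avoid constructing a tensor over a non-writable bytes buffer; the format string is made up:

import json
import torch

quant_conf = {"format": "float8_e4m3fn", "full_precision_matrix_mult": True}  # illustrative

# Serialize the per-layer quant config into a uint8 tensor, as above.
encoded = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)

# Inverse used at load time (the hunk decodes via .numpy().tobytes()).
decoded = json.loads(bytes(encoded.tolist()).decode("utf-8"))
assert decoded == quant_conf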


@ -411,7 +411,10 @@ class TensorCoreFP8Layout(QuantizedLayout):
orig_dtype = tensor.dtype
if isinstance(scale, str) and scale == "recalculate":
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
tensor_info = torch.finfo(tensor.dtype)
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
if scale is not None:
if not isinstance(scale, torch.Tensor):
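A sketch of the recalculated FP8 scale with the new clamp, assuming the running torch build exposes float8_e4m3fn (2.1+); the input is made up:

import torch

def compute_fp8_scale(tensor: torch.Tensor, fp8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # amax in fp32, then clamp the reciprocal so the scale cannot degenerate
    # when the source tensor is in a narrow dtype such as fp16.
    scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(fp8_dtype).max
    if tensor.dtype not in (torch.float32, torch.bfloat16):
        info = torch.finfo(tensor.dtype)
        scale = 1.0 / torch.clamp(1.0 / scale, min=info.min, max=info.max)
    return scale

print(compute_fp8_scale(torch.randn(16, 16, dtype=torch.float16)))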


@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
real_model = model.model
return real_model, conds, models


@ -720,7 +720,7 @@ class Sampler:
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",


@ -127,6 +127,8 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcoded upcast in TE implementation
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
@ -547,8 +549,10 @@ class VAE:
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
# Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:


@ -28,6 +28,7 @@ from . import supported_models_base
from . import latent_formats
from . import diffusers_convert
import comfy.model_management
class SD15(supported_models_base.BASE):
unet_config = {
@ -541,7 +542,7 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.SD3
memory_usage_factor = 1.2
memory_usage_factor = 1.6
text_encoder_key_prefix = ["text_encoders."]
@ -965,7 +966,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
@ -1026,9 +1027,15 @@ class ZImage(Lumina2):
"shift": 3.0,
}
memory_usage_factor = 1.7
memory_usage_factor = 2.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
self.supported_inference_dtypes.insert(1, torch.float16)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
@ -1289,7 +1296,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.038
memory_usage_factor = 0.044
def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device)
@ -1332,7 +1339,7 @@ class Omnigen2(supported_models_base.BASE):
"shift": 2.6,
}
memory_usage_factor = 1.65 #TODO
memory_usage_factor = 1.95 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
@ -1397,7 +1404,7 @@ class HunyuanImage21(HunyuanVideo):
latent_format = latent_formats.HunyuanImage21
memory_usage_factor = 7.7
memory_usage_factor = 8.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1488,7 +1495,7 @@ class Kandinsky5(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 1.1 #TODO
memory_usage_factor = 1.25 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1517,7 +1524,7 @@ class Kandinsky5Image(Kandinsky5):
}
latent_format = latent_formats.Flux
memory_usage_factor = 1.1 #TODO
memory_usage_factor = 1.25 #TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5Image(self, device=device)


@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
@ -803,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
return None
return f.read(length_of_header)
ATTR_UNSET={}
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], value)
prev = getattr(obj, attrs[-1], ATTR_UNSET)
if value is ATTR_UNSET:
delattr(obj, attrs[-1])
else:
setattr(obj, attrs[-1], value)
return prev
def set_attr_param(obj, attr, value):
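A usage sketch of set_attr with the new ATTR_UNSET sentinel, using a throwaway object in place of a real module:

import torch

ATTR_UNSET = {}  # sentinel, mirroring the hunk above

def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1], ATTR_UNSET)
    if value is ATTR_UNSET:
        delattr(obj, attrs[-1])
    else:
        setattr(obj, attrs[-1], value)
    return prev

class Inner: pass
class Outer:
    def __init__(self):
        self.inner = Inner()
        self.inner.weight = torch.zeros(2)

m = Outer()
old = set_attr(m, "inner.weight", torch.ones(2))  # replace, returns previous value
set_attr(m, "inner.weight", ATTR_UNSET)           # delete the attribute entirely
print(old, hasattr(m.inner, "weight"))            # tensor([0., 0.]) False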
@ -1257,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
if quant_metadata is not None:
layers = quant_metadata["layers"]
for k, v in layers.items():
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata


@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""
from typing import Any, Dict
from typing import Any
from comfy.cli_args import args
# Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: Dict[str, Any] = {
def get_connection_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sockets_metadata: dict[str, dict[str, Any]],
sid: str,
feature_name: str,
default: Any = False
@ -42,7 +42,7 @@ def get_connection_feature(
def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sockets_metadata: dict[str, dict[str, Any]],
sid: str,
feature_name: str
) -> bool:
@ -60,7 +60,7 @@ def supports_feature(
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> Dict[str, Any]:
def get_server_features() -> dict[str, Any]:
"""
Get the server's feature flags.


@ -1,4 +1,4 @@
from typing import Type, List, NamedTuple
from typing import NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version
@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
class ComfyAPIWithVersion(NamedTuple):
version: str
api_class: Type[ComfyAPIBase]
api_class: type[ComfyAPIBase]
def parse_version(version_str: str) -> packaging_version.Version:
@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
return packaging_version.parse(version_str)
registered_versions: List[ComfyAPIWithVersion] = []
registered_versions: list[ComfyAPIWithVersion] = []
def register_versions(versions: List[ComfyAPIWithVersion]):
def register_versions(versions: list[ComfyAPIWithVersion]):
versions.sort(key=lambda x: parse_version(x.version))
global registered_versions
registered_versions = versions
def get_all_versions() -> List[ComfyAPIWithVersion]:
def get_all_versions() -> list[ComfyAPIWithVersion]:
"""
Returns a list of all registered ComfyAPI versions.
"""


@ -8,7 +8,7 @@ import os
import textwrap
import threading
from enum import Enum
from typing import Optional, Type, get_origin, get_args, get_type_hints
from typing import Optional, get_origin, get_args, get_type_hints
class TypeTracker:
@ -193,7 +193,7 @@ class AsyncToSyncConverter:
return result_container["result"]
@classmethod
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
"""
Creates a new class with synchronous versions of all async methods.
@ -563,7 +563,7 @@ class AsyncToSyncConverter:
@classmethod
def _generate_imports(
cls, async_class: Type, type_tracker: TypeTracker
cls, async_class: type, type_tracker: TypeTracker
) -> list[str]:
"""Generate import statements for the stub file."""
imports = []
@ -628,7 +628,7 @@ class AsyncToSyncConverter:
return imports
@classmethod
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
"""Extract class attributes that are classes themselves."""
class_attributes = []
@ -654,7 +654,7 @@ class AsyncToSyncConverter:
def _generate_inner_class_stub(
cls,
name: str,
attr: Type,
attr: type,
indent: str = " ",
type_tracker: Optional[TypeTracker] = None,
) -> list[str]:
@ -782,7 +782,7 @@ class AsyncToSyncConverter:
return processed
@classmethod
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
"""
Generate a .pyi stub file for the sync class to help IDEs with type checking.
"""
@ -988,7 +988,7 @@ class AsyncToSyncConverter:
logging.error(traceback.format_exc())
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
def create_sync_class(async_class: type, thread_pool_size=10) -> type:
"""
Creates a sync version of an async class


@ -1,4 +1,4 @@
from typing import Type, TypeVar
from typing import TypeVar
class SingletonMetaclass(type):
T = TypeVar("T", bound="SingletonMetaclass")
@ -11,13 +11,13 @@ class SingletonMetaclass(type):
)
return cls._instances[cls]
def inject_instance(cls: Type[T], instance: T) -> None:
def inject_instance(cls: type[T], instance: T) -> None:
assert cls not in SingletonMetaclass._instances, (
"Cannot inject instance after first instantiation"
)
SingletonMetaclass._instances[cls] = instance
def get_instance(cls: Type[T], *args, **kwargs) -> T:
def get_instance(cls: type[T], *args, **kwargs) -> T:
"""
Gets the singleton instance of the class, creating it if it doesn't exist.
"""


@ -1,13 +1,13 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING
from typing import TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io
from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
@ -80,7 +80,7 @@ class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources neeeded by the extension.
This should be used to initialize any global resources needed by the extension.
"""
@abstractmethod
@ -113,7 +113,7 @@ ComfyAPI = ComfyAPI_latest
if TYPE_CHECKING:
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
# create new aliases for io and ui


@ -1,5 +1,5 @@
import torch
from typing import TypedDict, List, Optional
from typing import TypedDict, Optional
ImageInput = torch.Tensor
"""
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
Optional noise mask tensor in the same format as samples.
"""
batch_index: Optional[List[int]]
batch_index: Optional[list[int]]


@ -4,7 +4,7 @@ from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""


@ -3,14 +3,14 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput
from .._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import math
import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
from .._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:


@ -26,7 +26,7 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_api.latest._resources import Resources, ResourcesLocal
from ._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL
@ -774,6 +774,13 @@ class AudioEncoder(ComfyTypeIO):
class AudioEncoderOutput(ComfyTypeIO):
Type = Any
@comfytype(io_type="TRACKS")
class Tracks(ComfyTypeIO):
class TrackDict(TypedDict):
track_path: torch.Tensor
track_visibility: torch.Tensor
Type = TrackDict
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@ -1549,12 +1556,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None)
return type_clone
@final
@ -1815,7 +1822,7 @@ class NodeOutput(_NodeOutputInternal):
ui = data["ui"]
if "expand" in data:
expand = data["expand"]
return cls(args=args, ui=ui, expand=expand)
return cls(*args, ui=ui, expand=expand)
def __getitem__(self, index) -> Any:
return self.args[index]
@ -1894,6 +1901,7 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
"Tracks",
# Dynamic Types
"MatchType",
# "DynamicCombo",


@ -5,7 +5,6 @@ import os
import random
import uuid
from io import BytesIO
from typing import Type
import av
import numpy as np
@ -22,7 +21,7 @@ import folder_paths
# used for image preview
from comfy.cli_args import args
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
from ._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):
@ -83,7 +82,7 @@ class ImageSaveHelper:
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
@ -96,7 +95,7 @@ class ImageSaveHelper:
return metadata
@staticmethod
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
@ -121,7 +120,7 @@ class ImageSaveHelper:
return metadata
@staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None:
@ -137,7 +136,7 @@ class ImageSaveHelper:
@staticmethod
def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -155,7 +154,7 @@ class ImageSaveHelper:
return results
@staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output."""
return SavedImages(
ImageSaveHelper.save_images(
@ -169,7 +168,7 @@ class ImageSaveHelper:
@staticmethod
def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult:
"""Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -191,7 +190,7 @@ class ImageSaveHelper:
@staticmethod
def get_save_animated_png_ui(
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png(
@ -209,7 +208,7 @@ class ImageSaveHelper:
images,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
cls: type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
@ -238,7 +237,7 @@ class ImageSaveHelper:
def get_save_animated_webp_ui(
images,
filename_prefix: str,
cls: Type[ComfyNode] | None,
cls: type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
@ -267,7 +266,7 @@ class AudioSaveHelper:
audio: dict,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
cls: type[ComfyNode] | None,
format: str = "flac",
quality: str = "128k",
) -> list[SavedResult]:
@ -372,7 +371,7 @@ class AudioSaveHelper:
@staticmethod
def get_save_audio_ui(
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios:
"""Save and instantly wrap for UI."""
return SavedAudios(
@ -388,7 +387,7 @@ class AudioSaveHelper:
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images(
image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
@ -412,7 +411,7 @@ class PreviewMask(PreviewImage):
class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio(
audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import ImageInput, AudioInput
from .._input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"

View File

@ -2,9 +2,8 @@ from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase
from typing import List, Type
supported_versions: List[Type[ComfyAPIBase]] = [
supported_versions: list[type[ComfyAPIBase]] = [
ComfyAPI_latest,
ComfyAPIAdapter_v0_0_2,
ComfyAPIAdapter_v0_0_1,

View File

@ -0,0 +1,144 @@
from typing import Literal
from pydantic import BaseModel, Field
class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
size: str | None = Field(None)
seed: int | None = Field(0, ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: str | None = Field("adaptive")
seed: int | None = Field(..., ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str = Field("url")
image: list[str] | None = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: TaskStatusError | None = Field(None)
content: TaskStatusResult | None = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# The times in this dictionary are given for a 10-second duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
"seedance-1-0-pro-fast-251015": {
"480p": 50,
"720p": 65,
"1080p": 100,
},
}
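Since the baseline figures above are quoted for a 10-second clip, a polling helper has to scale them to the requested duration. A minimal sketch of how this table might be consumed, assuming roughly linear scaling with duration (the helper name and the scaling rule are illustrative assumptions, not names from the codebase):

def estimate_execution_time(model: str, resolution: str, duration: int) -> int | None:
    """Hypothetical helper: scale the 10-second baseline to the requested duration."""
    baseline = VIDEO_TASKS_EXECUTION_TIME.get(model, {}).get(resolution)
    if baseline is None:
        return None  # unknown model/resolution; the caller can fall back to no estimate
    return round(baseline * duration / 10)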

View File

@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
role: GeminiRole = Field(
...,
description="The identity of the entity that creates the message. "
"The following values are supported: "
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
"model: This indicates that the message is generated by the model. "
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
"For non-multi-turn conversations, this field can be left blank or unset.",
)
role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
class GeminiFunctionDeclaration(BaseModel):

View File

@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel):
url: str = Field(..., description="URL for generated image")
class OmniTaskStatusResults(BaseModel):
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
class OmniTaskStatusResponseData(BaseModel):
class TaskStatusResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID")
task_result: OmniTaskStatusResults | None = Field(None)
task_result: TaskStatusResults | None = Field(None)
class OmniTaskStatusResponse(BaseModel):
class TaskStatusResponse(BaseModel):
code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID")
data: OmniTaskStatusResponseData | None = Field(None)
data: TaskStatusResponseData | None = Field(None)
class OmniImageParamImage(BaseModel):
@ -84,3 +84,21 @@ class OmniProImageRequest(BaseModel):
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
image: str = Field(...)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")

View File

@ -0,0 +1,52 @@
from pydantic import BaseModel, Field
class Datum2(BaseModel):
b64_json: str | None = Field(None, description="Base64 encoded image data")
revised_prompt: str | None = Field(None, description="Revised prompt")
url: str | None = Field(None, description="URL of the image")
class InputTokensDetails(BaseModel):
image_tokens: int | None = None
text_tokens: int | None = None
class Usage(BaseModel):
input_tokens: int | None = None
input_tokens_details: InputTokensDetails | None = None
output_tokens: int | None = None
total_tokens: int | None = None
class OpenAIImageGenerationResponse(BaseModel):
data: list[Datum2] | None = None
usage: Usage | None = None
class OpenAIImageEditRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str = Field(...)
moderation: str | None = Field(None)
n: int | None = Field(None, description="The number of images to generate")
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
size: str | None = Field(None, description="Size of the output image")
class OpenAIImageGenerationRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str | None = Field(None)
moderation: str | None = Field(None)
n: int | None = Field(
None,
description="The number of images to generate.",
)
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="The quality of the generated image")
size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
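As a quick illustration of how these request models serialize into an API payload, a minimal sketch (the import path matches the one used by the OpenAI nodes later in this diff; the prompt and size values are arbitrary examples):

from comfy_api_nodes.apis.openai_api import OpenAIImageGenerationRequest

req = OpenAIImageGenerationRequest(
    model="gpt-image-1",
    prompt="a watercolor fox in a misty forest",
    n=1,
    size="1024x1024",
)
# exclude_none drops the unset optional fields from the JSON body (pydantic v2).
print(req.model_dump_json(exclude_none=True))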

View File

@ -1,100 +0,0 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'
class TripoGeometryQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
@ -61,14 +67,20 @@ class TripoSpec(str, Enum):
class TripoAnimation(str, Enum):
IDLE = "preset:idle"
WALK = "preset:walk"
RUN = "preset:run"
DIVE = "preset:dive"
CLIMB = "preset:climb"
JUMP = "preset:jump"
RUN = "preset:run"
SLASH = "preset:slash"
SHOOT = "preset:shoot"
HURT = "preset:hurt"
FALL = "preset:fall"
TURN = "preset:turn"
QUADRUPED_WALK = "preset:quadruped:walk"
HEXAPOD_WALK = "preset:hexapod:walk"
OCTOPOD_WALK = "preset:octopod:walk"
SERPENTINE_MARCH = "preset:serpentine:march"
AQUATIC_MARCH = "preset:aquatic:march"
class TripoStylizeStyle(str, Enum):
LEGO = "lego"
@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum):
BANNED = "banned"
EXPIRED = "expired"
class TripoFbxPreset(str, Enum):
BLENDER = "blender"
MIXAMO = "mixamo"
_3DSMAX = "3dsmax"
class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
file_token: str
@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(4096, description='The size of the texture')
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(None, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
class TripoTaskRequest(RootModel):
root: Union[

View File

@ -85,7 +85,7 @@ class Response1(BaseModel):
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = None
videos: Optional[list[Video]] = Field(None)
class VeoGenVidPollResponse(BaseModel):

View File

@ -1,13 +1,27 @@
import logging
import math
from enum import Enum
from typing import Literal, Optional, Union
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance_api import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
VIDEO_TASKS_EXECUTION_TIME,
Image2ImageTaskCreationRequest,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
Seedream4Options,
Seedream4TaskCreationRequest,
TaskCreationResponse,
TaskImageContent,
TaskImageContentUrl,
TaskStatusResponse,
TaskTextContent,
Text2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
class Text2ImageModelName(str, Enum):
seedream_3 = "seedream-3-0-t2i-250415"
class Image2ImageModelName(str, Enum):
seededit_3 = "seededit-3-0-i2i-250628"
class Text2VideoModelName(str, Enum):
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
class Image2VideoModelName(str, Enum):
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
class Text2ImageTaskCreationRequest(BaseModel):
model: Text2ImageModelName = Text2ImageModelName.seedream_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
size: Optional[str] = Field(None)
seed: Optional[int] = Field(0, ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: Image2ImageModelName = Image2ImageModelName.seededit_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: Optional[str] = Field("adaptive")
seed: Optional[int] = Field(..., ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field("seedream-4-0-250828")
prompt: str = Field(...)
response_format: str = Field("url")
image: Optional[list[str]] = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: Optional[TaskStatusError] = Field(None)
content: Optional[TaskStatusResult] = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# The time in this dictionary are given for 10 seconds duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
}
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
return response.data[0]["url"]
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "content") and response.content:
return response.content.video_url
return None
class ByteDanceImageNode(IO.ComfyNode):
@classmethod
@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
"model",
options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3,
tooltip="Model name",
),
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
IO.String.Input(
"prompt",
multiline=True,
@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Edit images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
"model",
options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3,
tooltip="Model name",
),
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
IO.Image.Input(
"image",
tooltip="The base image to edit",
@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
async def execute(
cls,
model: str,
image: torch.Tensor,
image: Input.Image,
prompt: str,
seed: int,
guidance_scale: float,
@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=["seedream-4-0-250828"],
options=["seedream-4-5-251128", "seedream-4-0-250828"],
tooltip="Model name",
),
IO.String.Input(
@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
step=64,
step=8,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
step=64,
step=8,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
image: torch.Tensor = None,
image: Input.Image | None = None,
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
width: int = 2048,
height: int = 2048,
@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
raise ValueError(
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
)
out_num_pixels = w * h
mp_provided = out_num_pixels / 1_000_000.0
if "seedream-4-5" in model and out_num_pixels < 3686400:
raise ValueError(
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
f"but {mp_provided:.2f}MP provided."
)
if "seedream-4-0" in model and out_num_pixels < 921600:
raise ValueError(
f"Minimum image resolution that the selected model can generate is 0.92MP, "
f"but {mp_provided:.2f}MP provided."
)
n_input_images = get_number_of_images(image) if image is not None else 0
if n_input_images > 10:
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=Text2VideoModelName,
default=Text2VideoModelName.seedance_1_pro,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=Image2VideoModelName,
default=Image2VideoModelName.seedance_1_pro,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default="seedance-1-0-pro-fast-251015",
),
IO.String.Input(
"prompt",
@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
image: torch.Tensor,
image: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
first_frame: Input.Image,
last_frame: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=[Image2VideoModelName.seedance_1_lite.value],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default="seedance-1-0-lite-i2v-250428",
),
IO.String.Input(
"prompt",
@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
images: torch.Tensor,
images: Input.Image,
resolution: str,
aspect_ratio: str,
duration: int,
@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
async def process_video_task(
cls: type[IO.ComfyNode],
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
estimated_duration: Optional[int],
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
@ -1085,7 +934,7 @@ async def process_video_task(
estimated_duration=estimated_duration,
response_model=TaskStatusResponse,
)
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:

View File

@ -13,8 +13,7 @@ import torch
from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.gemini_api import (
GeminiContent,
GeminiFileData,
@ -27,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import (
GeminiMimeType,
GeminiPart,
GeminiRole,
GeminiSystemInstructionContent,
GeminiTextPart,
Modality,
)
from comfy_api_nodes.util import (
@ -43,6 +44,14 @@ from comfy_api_nodes.util import (
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
GEMINI_IMAGE_SYS_PROMPT = (
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
"Interpret all user input—regardless of "
"format, intent, or abstraction—as literal visual directives for image composition.\n"
"If a prompt is conversational or lacks specific visual details, "
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
)
class GeminiModel(str, Enum):
@ -68,7 +77,7 @@ class GeminiImageModel(str, Enum):
async def create_image_parts(
cls: type[IO.ComfyNode],
images: torch.Tensor,
images: Input.Image,
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
@ -154,8 +163,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
image_tensors: list[torch.Tensor] = []
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
@ -277,6 +286,13 @@ class GeminiNode(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.String.Output(),
@ -293,7 +309,9 @@ class GeminiNode(IO.ComfyNode):
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
@ -343,10 +361,11 @@ class GeminiNode(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@ -363,7 +382,10 @@ class GeminiNode(IO.ComfyNode):
if files is not None:
parts.extend(files)
# Create response
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -373,7 +395,8 @@ class GeminiNode(IO.ComfyNode):
role=GeminiRole.user,
parts=parts,
)
]
],
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@ -523,6 +546,13 @@ class GeminiImage(IO.ComfyNode):
"'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True,
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -542,10 +572,11 @@ class GeminiImage(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@ -559,6 +590,10 @@ class GeminiImage(IO.ComfyNode):
if files is not None:
parts.extend(files)
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -570,6 +605,7 @@ class GeminiImage(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=None if aspect_ratio == "auto" else image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@ -640,6 +676,13 @@ class GeminiImage2(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -662,8 +705,9 @@ class GeminiImage2(IO.ComfyNode):
aspect_ratio: str,
resolution: str,
response_modalities: str,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
@ -679,6 +723,10 @@ class GeminiImage2(IO.ComfyNode):
if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -690,6 +738,7 @@ class GeminiImage2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,

View File

@ -50,6 +50,7 @@ from comfy_api_nodes.apis import (
KlingSingleImageEffectModelName,
)
from comfy_api_nodes.apis.kling_api import (
ImageToVideoWithAudioRequest,
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
@ -57,7 +58,8 @@ from comfy_api_nodes.apis.kling_api import (
OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
OmniTaskStatusResponse,
TaskStatusResponse,
TextToVideoWithAudioRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -103,10 +105,6 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
MODE_TEXT2VIDEO = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
@ -127,8 +125,6 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
MODE_START_END_FRAME = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
@ -242,7 +238,7 @@ def normalize_omni_prompt_references(prompt: str) -> str:
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@ -250,7 +246,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStat
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160,
)
@ -483,12 +479,12 @@ async def execute_image2video(
task_id = task_creation_response.data.task_id
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@ -752,7 +748,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[4],
default=modes[8],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -834,7 +830,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProText2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -929,7 +925,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProFirstLastFrameRequest(
model_name=model_name,
prompt=prompt,
@ -997,7 +993,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1081,7 +1077,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1162,7 +1158,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1237,7 +1233,7 @@ class OmniProImageNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProImageRequest(
model_name=model_name,
prompt=prompt,
@ -1253,7 +1249,7 @@ class OmniProImageNode(IO.ComfyNode):
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
@ -1328,9 +1324,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImage2VideoNode",
display_name="Kling Image to Video",
display_name="Kling Image(First Frame) to Video",
category="api node/video/Kling",
description="Kling Image to Video Node",
inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1488,7 +1483,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[8],
default=modes[6],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -1951,7 +1946,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Combo.Input(
"model_name",
options=[i.value for i in KlingImageGenModelName],
default="kling-v1",
default="kling-v2",
),
IO.Combo.Input(
"aspect_ratio",
@ -2034,6 +2029,136 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.NodeOutput(await image_result_to_node_output(images))
class TextToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
mode: str,
aspect_ratio: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
response_model=TaskStatusResponse,
data=TextToVideoWithAudioRequest(
model_name=model_name,
prompt=prompt,
mode=mode,
aspect_ratio=aspect_ratio,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class ImageToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
start_frame: Input.Image,
prompt: str,
mode: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
response_model=TaskStatusResponse,
data=ImageToVideoWithAudioRequest(
model_name=model_name,
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
prompt=prompt,
mode=mode,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class KlingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2056,7 +2181,9 @@ class KlingExtension(ComfyExtension):
OmniProImageToVideoNode,
OmniProVideoToVideoNode,
OmniProEditVideoNode,
# OmniProImageNode, # need support from backend
OmniProImageNode,
TextToVideoWithAudio,
ImageToVideoWithAudio,
]

View File

@ -1,12 +1,9 @@
from io import BytesIO
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: Optional[int] = Field(25)
generate_audio: Optional[bool] = Field(True)
image_uri: Optional[str] = Field(None)
fps: int | None = Field(25)
generate_audio: bool | None = Field(True)
image_uri: str | None = Field(None)
class TextToVideoNode(IO.ComfyNode):
@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
model: str,
prompt: str,
duration: int,
@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):

View File

@ -1,11 +1,8 @@
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.input import VideoInput
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
raise RuntimeError(error_msg)
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
return _validate_and_trim_duration(video)
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
prompt: str,
negative_prompt: str,
resolution: str,
@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
prompt: str,
negative_prompt: str,
seed: int,
video: Optional[VideoInput] = None,
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
motion_intensity: int | None = 100,
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:

View File

@ -1,46 +1,45 @@
from io import BytesIO
import base64
import os
from enum import Enum
from inspect import cleandoc
from io import BytesIO
import numpy as np
import torch
from PIL import Image
import folder_paths
import base64
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
OpenAIImageGenerationRequest,
OpenAIImageEditRequest,
OpenAIImageGenerationResponse,
OpenAICreateResponse,
OpenAIResponse,
CreateModelResponseProperties,
Item,
OutputContent,
InputImageContent,
Detail,
InputTextContent,
InputMessage,
InputMessageContentList,
InputContent,
InputFileContent,
InputImageContent,
InputMessage,
InputMessageContentList,
InputTextContent,
Item,
OpenAICreateResponse,
OpenAIResponse,
OutputContent,
)
from comfy_api_nodes.apis.openai_api import (
OpenAIImageEditRequest,
OpenAIImageGenerationRequest,
OpenAIImageGenerationResponse,
)
from comfy_api_nodes.util import (
downscale_image_tensor,
download_url_to_bytesio,
validate_string,
tensor_to_base64_string,
ApiEndpoint,
sync_op,
download_url_to_bytesio,
downscale_image_tensor,
poll_op,
sync_op,
tensor_to_base64_string,
text_filepath_to_data_uri,
validate_string,
)
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
@ -98,9 +97,6 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
class OpenAIDalle2(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
"""
@classmethod
def define_schema(cls):
@ -108,7 +104,7 @@ class OpenAIDalle2(IO.ComfyNode):
node_id="OpenAIDalle2",
display_name="OpenAI DALL·E 2",
category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
inputs=[
IO.String.Input(
"prompt",
@ -234,9 +230,6 @@ class OpenAIDalle2(IO.ComfyNode):
class OpenAIDalle3(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
"""
@classmethod
def define_schema(cls):
@ -244,7 +237,7 @@ class OpenAIDalle3(IO.ComfyNode):
node_id="OpenAIDalle3",
display_name="OpenAI DALL·E 3",
category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
inputs=[
IO.String.Input(
"prompt",
@ -326,10 +319,16 @@ class OpenAIDalle3(IO.ComfyNode):
return IO.NodeOutput(await validate_and_cast_response(response))
def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None:
# https://platform.openai.com/docs/pricing
return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0
def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None:
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
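Under the per-million-token rates encoded above (10/40 USD for gpt-image-1 and 8/32 USD for gpt-image-1.5, input/output respectively), a response that used, say, 1,000 input and 2,000 output tokens would be priced at (1,000 × 10 + 2,000 × 40) / 1,000,000 = 0.09 USD on gpt-image-1 and (1,000 × 8 + 2,000 × 32) / 1,000,000 = 0.072 USD on gpt-image-1.5; the token counts here are illustrative only.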
class OpenAIGPTImage1(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
"""
@classmethod
def define_schema(cls):
@ -337,13 +336,13 @@ class OpenAIGPTImage1(IO.ComfyNode):
node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 1",
category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
inputs=[
IO.String.Input(
"prompt",
default="",
multiline=True,
tooltip="Text prompt for GPT Image 1",
tooltip="Text prompt for GPT Image",
),
IO.Int.Input(
"seed",
@ -365,8 +364,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
),
IO.Combo.Input(
"background",
default="opaque",
options=["opaque", "transparent"],
default="auto",
options=["auto", "opaque", "transparent"],
tooltip="Return image with or without background",
optional=True,
),
@ -397,6 +396,11 @@ class OpenAIGPTImage1(IO.ComfyNode):
tooltip="Optional mask for inpainting (white areas will be replaced)",
optional=True,
),
IO.Combo.Input(
"model",
options=["gpt-image-1", "gpt-image-1.5"],
optional=True,
),
],
outputs=[
IO.Image.Output(),
@ -412,32 +416,34 @@ class OpenAIGPTImage1(IO.ComfyNode):
@classmethod
async def execute(
cls,
prompt,
seed=0,
quality="low",
background="opaque",
image=None,
mask=None,
n=1,
size="1024x1024",
prompt: str,
seed: int = 0,
quality: str = "low",
background: str = "opaque",
image: Input.Image | None = None,
mask: Input.Image | None = None,
n: int = 1,
size: str = "1024x1024",
model: str = "gpt-image-1",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
model = "gpt-image-1"
path = "/proxy/openai/images/generations"
content_type = "application/json"
request_class = OpenAIImageGenerationRequest
files = []
if mask is not None and image is None:
raise ValueError("Cannot use a mask without an input image")
if model == "gpt-image-1":
price_extractor = calculate_tokens_price_image_1
elif model == "gpt-image-1.5":
price_extractor = calculate_tokens_price_image_1_5
else:
raise ValueError(f"Unknown model: {model}")
if image is not None:
path = "/proxy/openai/images/edits"
request_class = OpenAIImageEditRequest
content_type = "multipart/form-data"
files = []
batch_size = image.shape[0]
for i in range(batch_size):
single_image = image[i : i + 1]
scaled_image = downscale_image_tensor(single_image).squeeze()
single_image = image[i: i + 1]
scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
@ -450,44 +456,59 @@ class OpenAIGPTImage1(IO.ComfyNode):
else:
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
if mask is not None:
if image is None:
raise Exception("Cannot use a mask without an input image")
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
batch, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
if mask is not None:
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
_, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
# Build the operation
response = await sync_op(
cls,
ApiEndpoint(path=path, method="POST"),
response_model=OpenAIImageGenerationResponse,
data=request_class(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
),
files=files if files else None,
content_type=content_type,
)
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageEditRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
moderation="low",
),
content_type="multipart/form-data",
files=files,
price_extractor=price_extractor,
)
else:
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageGenerationRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
moderation="low",
),
price_extractor=price_extractor,
)
return IO.NodeOutput(await validate_and_cast_response(response))

View File

@ -1,575 +0,0 @@
"""
Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
"""
from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint,
sync_op,
poll_op,
)
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
PIKA_API_VERSION = "2.2"
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
estimated_duration=60,
max_poll_attempts=240,
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
raise Exception(error_msg)
video_url = final_response.url
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return IO.NodeOutput(await download_url_to_video_output(video_url))
def get_base_inputs_types() -> list[IO.Input]:
"""Get the base required inputs types common to all Pika nodes."""
return [
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
IO.Combo.Input("duration", options=[5, 10], default=5),
]
class PikaImageToVideo(IO.ComfyNode):
"""Pika 2.2 Image to Video Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaImageToVideoNode2_2",
display_name="Pika Image to Video",
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The image to convert to video"),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
image_bytes_io = tensor_to_bytesio(image)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode):
"""Pika Text2Video v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaTextToVideoNode2_2",
display_name="Pika Text to Video",
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
)
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
),
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode):
"""PikaScenes v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaScenesV2_2",
display_name="Pika Scenes (Video Image Composition)",
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Combo.Input(
"ingredients_mode",
options=["creative", "precise"],
default="creative",
),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
),
IO.Image.Input(
"image_ingredient_1",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_2",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_3",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_4",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_5",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
ingredients_mode: str,
aspect_ratio: float,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
all_image_bytes_io = []
for image in [
image_ingredient_1,
image_ingredient_2,
image_ingredient_3,
image_ingredient_4,
image_ingredient_5,
]:
if image is not None:
all_image_bytes_io.append(tensor_to_bytesio(image))
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode):
"""Pika Pikadditions Node. Add an image into a video."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikadditions",
display_name="Pikadditions (Video Object Insertion)",
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to add an image to."),
IO.Image.Input("image", tooltip="The image to add to the video."),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input(
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
image_bytes_io = tensor_to_bytesio(image)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode):
"""Pika Pikaswaps Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaswaps",
display_name="Pika Swaps (Video Object Replacement)",
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to swap an object in."),
IO.Image.Input(
"image",
tooltip="The image used to replace the masked object in the video.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
IO.String.Input("prompt_text", multiline=True, optional=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
IO.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
prompt_text: str = "",
negative_prompt: str = "",
seed: int = 0,
region_to_modify: str = "",
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
}
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode):
"""Pika Pikaffects Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaffects",
display_name="Pikaffects (Video Effects)",
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
IO.Combo.Input(
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
pikaffect: str,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode):
"""PikaFrames v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaStartEndFrameNode2_2",
display_name="Pika Start and End Frame to Video",
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image_start", tooltip="The first image to combine."),
IO.Image.Input("image_end", tooltip="The last image to combine."),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image_start: torch.Tensor,
image_end: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
),
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PikaImageToVideo,
PikaTextToVideoNode,
PikaScenes,
PikAdditionsNode,
PikaSwapsNode,
PikaffectsNode,
PikaStartEndFrameNode,
]
async def comfy_entrypoint() -> PikaApiNodesExtension:
return PikaApiNodesExtension()

View File

@ -11,12 +11,11 @@ User Guides:
"""
from typing import Union, Optional
from typing_extensions import override
from enum import Enum
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
sync_op,
poll_op,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
field_1280_768 = "1280:768"
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> Union[float, None]:
) -> float | None:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the image URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
async def get_response(
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_op(
@ -119,8 +116,8 @@ async def get_response(
async def generate_video(
cls: type[IO.ComfyNode],
request: RunwayImageToVideoRequest,
estimated_duration: Optional[int] = None,
) -> VideoFromFile:
estimated_duration: int | None = None,
) -> InputImpl.VideoFromFile:
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
start_frame: Input.Image,
end_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
cls,
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
reference_image: Input.Image | None = None,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)

View File

@ -102,8 +102,9 @@ class TripoTextToModelNode(IO.ComfyNode):
IO.Int.Input("model_seed", default=42, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -131,6 +132,7 @@ class TripoTextToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
) -> IO.NodeOutput:
@ -154,6 +156,7 @@ class TripoTextToModelNode(IO.ComfyNode):
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
geometry_quality=geometry_quality,
auto_size=True,
quad=quad,
),
@ -194,6 +197,7 @@ class TripoImageToModelNode(IO.ComfyNode):
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -220,6 +224,7 @@ class TripoImageToModelNode(IO.ComfyNode):
orientation=None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
@ -246,6 +251,7 @@ class TripoImageToModelNode(IO.ComfyNode):
pbr=pbr,
model_seed=model_seed,
orientation=orientation,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
texture_seed=texture_seed,
texture_quality=texture_quality,
@ -295,6 +301,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -323,6 +330,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
@ -359,6 +367,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
face_limit=face_limit,
quad=quad,
@ -508,6 +517,8 @@ class TripoRetargetNode(IO.ComfyNode):
options=[
"preset:idle",
"preset:walk",
"preset:run",
"preset:dive",
"preset:climb",
"preset:jump",
"preset:slash",
@ -515,6 +526,11 @@ class TripoRetargetNode(IO.ComfyNode):
"preset:hurt",
"preset:fall",
"preset:turn",
"preset:quadruped:walk",
"preset:hexapod:walk",
"preset:octopod:walk",
"preset:serpentine:march",
"preset:aquatic:march"
],
),
],
@ -563,7 +579,7 @@ class TripoConversionNode(IO.ComfyNode):
"face_limit",
default=-1,
min=-1,
max=500000,
max=2000000,
optional=True,
),
IO.Int.Input(
@ -579,6 +595,40 @@ class TripoConversionNode(IO.ComfyNode):
default="JPEG",
optional=True,
),
IO.Boolean.Input("force_symmetry", default=False, optional=True),
IO.Boolean.Input("flatten_bottom", default=False, optional=True),
IO.Float.Input(
"flatten_bottom_threshold",
default=0.0,
min=0.0,
max=1.0,
optional=True,
),
IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
IO.Float.Input(
"scale_factor",
default=1.0,
min=0.0,
optional=True,
),
IO.Boolean.Input("with_animation", default=False, optional=True),
IO.Boolean.Input("pack_uv", default=False, optional=True),
IO.Boolean.Input("bake", default=False, optional=True),
IO.String.Input("part_names", default="", optional=True), # comma-separated list
IO.Combo.Input(
"fbx_preset",
options=["blender", "mixamo", "3dsmax"],
default="blender",
optional=True,
),
IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
IO.Combo.Input(
"export_orientation",
options=["align_image", "default"],
default="default",
optional=True,
),
IO.Boolean.Input("animate_in_place", default=False, optional=True),
],
outputs=[],
hidden=[
@ -604,12 +654,31 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id,
format: str,
quad: bool,
force_symmetry: bool,
face_limit: int,
flatten_bottom: bool,
flatten_bottom_threshold: float,
texture_size: int,
texture_format: str,
pivot_to_center_bottom: bool,
scale_factor: float,
with_animation: bool,
pack_uv: bool,
bake: bool,
part_names: str,
fbx_preset: str,
export_vertex_colors: bool,
export_orientation: str,
animate_in_place: bool,
) -> IO.NodeOutput:
if not original_model_task_id:
raise RuntimeError("original_model_task_id is required")
# Parse part_names from comma-separated string to list
part_names_list = None
if part_names and part_names.strip():
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
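# e.g. part_names = "head, torso,  legs"  ->  ["head", "torso", "legs"]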
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
@ -618,9 +687,22 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id=original_model_task_id,
format=format,
quad=quad if quad else None,
force_symmetry=force_symmetry if force_symmetry else None,
face_limit=face_limit if face_limit != -1 else None,
flatten_bottom=flatten_bottom if flatten_bottom else None,
flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
texture_size=texture_size if texture_size != 4096 else None,
texture_format=texture_format if texture_format != "JPEG" else None,
pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
scale_factor=scale_factor if scale_factor != 1.0 else None,
with_animation=with_animation if with_animation else None,
pack_uv=pack_uv if pack_uv else None,
bake=bake if bake else None,
part_names=part_names_list,
fbx_preset=fbx_preset if fbx_preset != "blender" else None,
export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
export_orientation=export_orientation if export_orientation != "default" else None,
animate_in_place=animate_in_place if animate_in_place else None,
),
)
return await poll_until_finished(cls, response, average_duration=30)

View File

@ -1,11 +1,9 @@
import base64
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
@ -232,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Check if video is provided as base64 or URL
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if hasattr(video, "gcsUri") and video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
@ -431,8 +429,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
seed: int,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
first_frame: Input.Image,
last_frame: Input.Image,
model: str,
generate_audio: bool,
):
@ -493,7 +491,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")

View File

@ -1,7 +1,5 @@
import re
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
@ -21,26 +19,26 @@ from comfy_api_nodes.util import (
class Text2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
negative_prompt: str | None = Field(None)
class Image2ImageInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
negative_prompt: str | None = Field(None)
images: list[str] = Field(..., min_length=1, max_length=2)
class Text2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
audio_url: Optional[str] = Field(None)
negative_prompt: str | None = Field(None)
audio_url: str | None = Field(None)
class Image2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: Optional[str] = Field(None)
negative_prompt: str | None = Field(None)
img_url: str = Field(...)
audio_url: Optional[str] = Field(None)
audio_url: str | None = Field(None)
class Txt2ImageParametersField(BaseModel):
@ -52,7 +50,7 @@ class Txt2ImageParametersField(BaseModel):
class Image2ImageParametersField(BaseModel):
size: Optional[str] = Field(None)
size: str | None = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(True)
@ -61,19 +59,21 @@ class Image2ImageParametersField(BaseModel):
class Text2VideoParametersField(BaseModel):
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=10)
duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
audio: bool = Field(False, description="Should be audio generated automatically")
audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single")
class Image2VideoParametersField(BaseModel):
resolution: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=10)
duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
audio: bool = Field(False, description="Should be audio generated automatically")
audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single")
class Text2ImageTaskCreationRequest(BaseModel):
@ -106,39 +106,39 @@ class TaskCreationOutputField(BaseModel):
class TaskCreationResponse(BaseModel):
output: Optional[TaskCreationOutputField] = Field(None)
output: TaskCreationOutputField | None = Field(None)
request_id: str = Field(...)
code: Optional[str] = Field(None, description="The error code of the failed request.")
message: Optional[str] = Field(None, description="Details of the failed request.")
code: str | None = Field(None, description="Error code for the failed request.")
message: str | None = Field(None, description="Details about the failed request.")
class TaskResult(BaseModel):
url: Optional[str] = Field(None)
code: Optional[str] = Field(None)
message: Optional[str] = Field(None)
url: str | None = Field(None)
code: str | None = Field(None)
message: str | None = Field(None)
class ImageTaskStatusOutputField(TaskCreationOutputField):
task_id: str = Field(...)
task_status: str = Field(...)
results: Optional[list[TaskResult]] = Field(None)
results: list[TaskResult] | None = Field(None)
class VideoTaskStatusOutputField(TaskCreationOutputField):
task_id: str = Field(...)
task_status: str = Field(...)
video_url: Optional[str] = Field(None)
code: Optional[str] = Field(None)
message: Optional[str] = Field(None)
video_url: str | None = Field(None)
code: str | None = Field(None)
message: str | None = Field(None)
class ImageTaskStatusResponse(BaseModel):
output: Optional[ImageTaskStatusOutputField] = Field(None)
output: ImageTaskStatusOutputField | None = Field(None)
request_id: str = Field(...)
class VideoTaskStatusResponse(BaseModel):
output: Optional[VideoTaskStatusOutputField] = Field(None)
output: VideoTaskStatusOutputField | None = Field(None)
request_id: str = Field(...)
@ -152,7 +152,7 @@ class WanTextToImageApi(IO.ComfyNode):
node_id="WanTextToImageApi",
display_name="Wan Text to Image",
category="api node/image/Wan",
description="Generates image based on text prompt.",
description="Generates an image based on a text prompt.",
inputs=[
IO.Combo.Input(
"model",
@ -164,13 +164,13 @@ class WanTextToImageApi(IO.ComfyNode):
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
tooltip="Negative prompt describing what to avoid.",
optional=True,
),
IO.Int.Input(
@ -209,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip='Whether to add an "AI generated" watermark to the result.',
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
],
@ -252,7 +252,7 @@ class WanTextToImageApi(IO.ComfyNode):
),
)
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@ -272,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode):
display_name="Wan Image to Image",
category="api node/image/Wan",
description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
inputs=[
IO.Combo.Input(
"model",
@ -282,19 +282,19 @@ class WanImageToImageApi(IO.ComfyNode):
),
IO.Image.Input(
"image",
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
tooltip="Single-image editing or multi-image fusion. Maximum 2 images.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
tooltip="Negative prompt describing what to avoid.",
optional=True,
),
# redo this later as an optional combo of recommended resolutions
@ -328,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip='Whether to add an "AI generated" watermark to the result.',
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
],
@ -347,7 +347,7 @@ class WanImageToImageApi(IO.ComfyNode):
async def execute(
cls,
model: str,
image: torch.Tensor,
image: Input.Image,
prompt: str,
negative_prompt: str = "",
# width: int = 1024,
@ -357,7 +357,7 @@ class WanImageToImageApi(IO.ComfyNode):
):
n_images = get_number_of_images(image)
if n_images not in (1, 2):
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.")
images = []
for i in image:
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
@ -376,7 +376,7 @@ class WanImageToImageApi(IO.ComfyNode):
),
)
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@ -395,25 +395,25 @@ class WanTextToVideoApi(IO.ComfyNode):
node_id="WanTextToVideoApi",
display_name="Wan Text to Video",
category="api node/video/Wan",
description="Generates video based on text prompt.",
description="Generates a video based on a text prompt.",
inputs=[
IO.Combo.Input(
"model",
options=["wan2.5-t2v-preview"],
default="wan2.5-t2v-preview",
options=["wan2.5-t2v-preview", "wan2.6-t2v"],
default="wan2.6-t2v",
tooltip="Model to use.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
tooltip="Negative prompt describing what to avoid.",
optional=True,
),
IO.Combo.Input(
@ -433,23 +433,23 @@ class WanTextToVideoApi(IO.ComfyNode):
"1080p: 4:3 (1632x1248)",
"1080p: 3:4 (1248x1632)",
],
default="480p: 1:1 (624x624)",
default="720p: 1:1 (960x960)",
optional=True,
),
IO.Int.Input(
"duration",
default=5,
min=5,
max=10,
max=15,
step=5,
display_mode=IO.NumberDisplay.number,
tooltip="Available durations: 5 and 10 seconds",
tooltip="A 15-second duration is available only for the Wan 2.6 model.",
optional=True,
),
IO.Audio.Input(
"audio",
optional=True,
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
),
IO.Int.Input(
"seed",
@ -466,7 +466,7 @@ class WanTextToVideoApi(IO.ComfyNode):
"generate_audio",
default=False,
optional=True,
tooltip="If there is no audio input, generate audio automatically.",
tooltip="If no audio input is provided, generate audio automatically.",
),
IO.Boolean.Input(
"prompt_extend",
@ -477,7 +477,15 @@ class WanTextToVideoApi(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip='Whether to add an "AI generated" watermark to the result.',
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
IO.Combo.Input(
"shot_type",
options=["single", "multi"],
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
"single continuous shot or multiple shots with cuts. "
"This parameter takes effect only when prompt_extend is True.",
optional=True,
),
],
@ -498,14 +506,19 @@ class WanTextToVideoApi(IO.ComfyNode):
model: str,
prompt: str,
negative_prompt: str = "",
size: str = "480p: 1:1 (624x624)",
size: str = "720p: 1:1 (960x960)",
duration: int = 5,
audio: Optional[Input.Audio] = None,
audio: Input.Audio | None = None,
seed: int = 0,
generate_audio: bool = False,
prompt_extend: bool = True,
watermark: bool = True,
shot_type: str = "single",
):
if "480p" in size and model == "wan2.6-t2v":
raise ValueError("The Wan 2.6 model does not support 480p.")
if duration == 15 and model == "wan2.5-t2v-preview":
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
width, height = RES_IN_PARENS.search(size).groups()
audio_url = None
if audio is not None:
@ -526,11 +539,12 @@ class WanTextToVideoApi(IO.ComfyNode):
audio=generate_audio,
prompt_extend=prompt_extend,
watermark=watermark,
shot_type=shot_type,
),
),
)
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@ -549,12 +563,12 @@ class WanImageToVideoApi(IO.ComfyNode):
node_id="WanImageToVideoApi",
display_name="Wan Image to Video",
category="api node/video/Wan",
description="Generates video based on the first frame and text prompt.",
description="Generates a video from the first frame and a text prompt.",
inputs=[
IO.Combo.Input(
"model",
options=["wan2.5-i2v-preview"],
default="wan2.5-i2v-preview",
options=["wan2.5-i2v-preview", "wan2.6-i2v"],
default="wan2.6-i2v",
tooltip="Model to use.",
),
IO.Image.Input(
@ -564,13 +578,13 @@ class WanImageToVideoApi(IO.ComfyNode):
"prompt",
multiline=True,
default="",
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid.",
tooltip="Negative prompt describing what to avoid.",
optional=True,
),
IO.Combo.Input(
@ -580,23 +594,23 @@ class WanImageToVideoApi(IO.ComfyNode):
"720P",
"1080P",
],
default="480P",
default="720P",
optional=True,
),
IO.Int.Input(
"duration",
default=5,
min=5,
max=10,
max=15,
step=5,
display_mode=IO.NumberDisplay.number,
tooltip="Available durations: 5 and 10 seconds",
tooltip="Duration 15 available only for WAN2.6 model.",
optional=True,
),
IO.Audio.Input(
"audio",
optional=True,
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
),
IO.Int.Input(
"seed",
@ -613,7 +627,7 @@ class WanImageToVideoApi(IO.ComfyNode):
"generate_audio",
default=False,
optional=True,
tooltip="If there is no audio input, generate audio automatically.",
tooltip="If no audio input is provided, generate audio automatically.",
),
IO.Boolean.Input(
"prompt_extend",
@ -624,7 +638,15 @@ class WanImageToVideoApi(IO.ComfyNode):
IO.Boolean.Input(
"watermark",
default=True,
tooltip='Whether to add an "AI generated" watermark to the result.',
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
IO.Combo.Input(
"shot_type",
options=["single", "multi"],
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
"single continuous shot or multiple shots with cuts. "
"This parameter takes effect only when prompt_extend is True.",
optional=True,
),
],
@ -643,19 +665,24 @@ class WanImageToVideoApi(IO.ComfyNode):
async def execute(
cls,
model: str,
image: torch.Tensor,
image: Input.Image,
prompt: str,
negative_prompt: str = "",
resolution: str = "480P",
resolution: str = "720P",
duration: int = 5,
audio: Optional[Input.Audio] = None,
audio: Input.Audio | None = None,
seed: int = 0,
generate_audio: bool = False,
prompt_extend: bool = True,
watermark: bool = True,
shot_type: str = "single",
):
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
if "480P" in resolution and model == "wan2.6-i2v":
raise ValueError("The Wan 2.6 model does not support 480P.")
if duration == 15 and model == "wan2.5-i2v-preview":
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
audio_url = None
if audio is not None:
@ -677,11 +704,12 @@ class WanImageToVideoApi(IO.ComfyNode):
audio=generate_audio,
prompt_extend=prompt_extend,
watermark=watermark,
shot_type=shot_type,
),
),
)
if not initial_response.output:
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),

View File

@ -129,7 +129,7 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
return img_byte_arr
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)

comfy_execution/jobs.py Normal file
View File

@ -0,0 +1,291 @@
"""
Job utilities for the /api/jobs endpoint.
Provides normalization and helper functions for job status tracking.
"""
from typing import Optional
from comfy_api.internal import prune_dict
class JobStatus:
"""Job status constants."""
PENDING = 'pending'
IN_PROGRESS = 'in_progress'
COMPLETED = 'completed'
FAILED = 'failed'
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
"""Extract create_time and workflow_id from extra_data.
Returns:
tuple: (create_time, workflow_id)
"""
create_time = extra_data.get('create_time')
extra_pnginfo = extra_data.get('extra_pnginfo', {})
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
return create_time, workflow_id
def is_previewable(media_type: str, item: dict) -> bool:
"""
Check if an output item is previewable.
Matches frontend logic in ComfyUI_frontend/src/stores/queueStore.ts
Maintains backwards compatibility with existing logic.
Priority:
1. media_type is 'images', 'video', or 'audio'
2. format field starts with 'video/' or 'audio/'
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
"""
if media_type in PREVIEWABLE_MEDIA_TYPES:
return True
# Check format field (MIME type).
# Maintains backwards compatibility with how custom node outputs are handled in the frontend.
fmt = item.get('format', '')
if fmt and (fmt.startswith('video/') or fmt.startswith('audio/')):
return True
# Check for 3D files by extension
filename = item.get('filename', '').lower()
if any(filename.endswith(ext) for ext in THREE_D_EXTENSIONS):
return True
return False
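A small illustrative check of the three rules above (the output items are hypothetical, not taken from a real history entry):
assert is_previewable("images", {"filename": "out.png"})                          # rule 1: previewable media_type
assert is_previewable("files", {"filename": "clip.bin", "format": "video/mp4"})   # rule 2: MIME prefix
assert is_previewable("files", {"filename": "mesh.GLB"})                          # rule 3: 3D extension, case-insensitive
assert not is_previewable("files", {"filename": "notes.txt"})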
def normalize_queue_item(item: tuple, status: str) -> dict:
"""Convert queue item tuple to unified job dict.
Expects item with sensitive data already removed (5 elements).
"""
priority, prompt_id, _, extra_data, _ = item
create_time, workflow_id = _extract_job_metadata(extra_data)
return prune_dict({
'id': prompt_id,
'status': status,
'priority': priority,
'create_time': create_time,
'outputs_count': 0,
'workflow_id': workflow_id,
})
def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: bool = False) -> dict:
"""Convert history item dict to unified job dict.
History items have sensitive data already removed (prompt tuple has 5 elements).
"""
prompt_tuple = history_item['prompt']
priority, _, prompt, extra_data, _ = prompt_tuple
create_time, workflow_id = _extract_job_metadata(extra_data)
status_info = history_item.get('status', {})
status_str = status_info.get('status_str') if status_info else None
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.FAILED
else:
status = JobStatus.COMPLETED
outputs = history_item.get('outputs', {})
outputs_count, preview_output = get_outputs_summary(outputs)
execution_error = None
execution_start_time = None
execution_end_time = None
if status_info:
messages = status_info.get('messages', [])
for entry in messages:
if isinstance(entry, (list, tuple)) and len(entry) >= 2:
event_name, event_data = entry[0], entry[1]
if isinstance(event_data, dict):
if event_name == 'execution_start':
execution_start_time = event_data.get('timestamp')
elif event_name in ('execution_success', 'execution_error', 'execution_interrupted'):
execution_end_time = event_data.get('timestamp')
if event_name == 'execution_error':
execution_error = event_data
job = prune_dict({
'id': prompt_id,
'status': status,
'priority': priority,
'create_time': create_time,
'execution_start_time': execution_start_time,
'execution_end_time': execution_end_time,
'execution_error': execution_error,
'outputs_count': outputs_count,
'preview_output': preview_output,
'workflow_id': workflow_id,
})
if include_outputs:
job['outputs'] = outputs
job['execution_status'] = status_info
job['workflow'] = {
'prompt': prompt,
'extra_data': extra_data,
}
return job
def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
"""
Count outputs and find preview in a single pass.
Returns (outputs_count, preview_output).
Preview priority (matching frontend):
1. type="output" with previewable media
2. Any previewable media
"""
count = 0
preview_output = None
fallback_preview = None
for node_id, node_outputs in outputs.items():
if not isinstance(node_outputs, dict):
continue
for media_type, items in node_outputs.items():
# 'animated' is a boolean flag, not actual output items
if media_type == 'animated' or not isinstance(items, list):
continue
for item in items:
if not isinstance(item, dict):
continue
count += 1
if preview_output is None and is_previewable(media_type, item):
enriched = {
**item,
'nodeId': node_id,
'mediaType': media_type
}
if item.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched
return count, preview_output or fallback_preview
def apply_sorting(jobs: list[dict], sort_by: str, sort_order: str) -> list[dict]:
"""Sort jobs list by specified field and order."""
reverse = (sort_order == 'desc')
if sort_by == 'execution_duration':
def get_sort_key(job):
start = job.get('execution_start_time', 0)
end = job.get('execution_end_time', 0)
return end - start if end and start else 0
else:
def get_sort_key(job):
return job.get('create_time', 0)
return sorted(jobs, key=get_sort_key, reverse=reverse)
def get_job(prompt_id: str, running: list, queued: list, history: dict) -> Optional[dict]:
"""
Get a single job by prompt_id from history or queue.
Args:
prompt_id: The prompt ID to look up
running: List of currently running queue items
queued: List of pending queue items
history: Dict of history items keyed by prompt_id
Returns:
Job dict with full details, or None if not found
"""
if prompt_id in history:
return normalize_history_item(prompt_id, history[prompt_id], include_outputs=True)
for item in running:
if item[1] == prompt_id:
return normalize_queue_item(item, JobStatus.IN_PROGRESS)
for item in queued:
if item[1] == prompt_id:
return normalize_queue_item(item, JobStatus.PENDING)
return None
def get_all_jobs(
running: list,
queued: list,
history: dict,
status_filter: Optional[list[str]] = None,
workflow_id: Optional[str] = None,
sort_by: str = "created_at",
sort_order: str = "desc",
limit: Optional[int] = None,
offset: int = 0
) -> tuple[list[dict], int]:
"""
Get all jobs (running, pending, completed) with filtering and sorting.
Args:
running: List of currently running queue items
queued: List of pending queue items
history: Dict of history items keyed by prompt_id
status_filter: List of statuses to include (from JobStatus.ALL)
workflow_id: Filter by workflow ID
sort_by: Field to sort by ('created_at', 'execution_duration')
sort_order: 'asc' or 'desc'
limit: Maximum number of items to return
offset: Number of items to skip
Returns:
tuple: (jobs_list, total_count)
"""
jobs = []
if status_filter is None:
status_filter = JobStatus.ALL
if JobStatus.IN_PROGRESS in status_filter:
for item in running:
jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))
if JobStatus.PENDING in status_filter:
for item in queued:
jobs.append(normalize_queue_item(item, JobStatus.PENDING))
include_completed = JobStatus.COMPLETED in status_filter
include_failed = JobStatus.FAILED in status_filter
if include_completed or include_failed:
for prompt_id, history_item in history.items():
is_failed = history_item.get('status', {}).get('status_str') == 'error'
if (is_failed and include_failed) or (not is_failed and include_completed):
jobs.append(normalize_history_item(prompt_id, history_item))
if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
jobs = apply_sorting(jobs, sort_by, sort_order)
total_count = len(jobs)
if offset > 0:
jobs = jobs[offset:]
if limit is not None:
jobs = jobs[:limit]
return (jobs, total_count)
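A minimal usage sketch (the 5-element queue tuple layout follows the unpacking in normalize_queue_item above; the concrete values are hypothetical):
running = []
queued = [(0, "abc123", {}, {"create_time": 1700000000}, [])]
history = {}  # would map prompt_id -> {"prompt": ..., "status": ..., "outputs": ...}
jobs, total = get_all_jobs(
    running, queued, history,
    status_filter=[JobStatus.PENDING, JobStatus.COMPLETED],
    sort_by="created_at", sort_order="desc",
    limit=20, offset=0,
)
assert total == 1 and jobs[0]["status"] == JobStatus.PENDING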

View File

@ -659,6 +659,40 @@ class SamplerSASolver(io.ComfyNode):
get_sampler = execute
class SamplerSEEDS2(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
],
outputs=[io.Sampler.Output()],
description=(
"This sampler node can represent multiple samplers:\n\n"
"seeds_2\n"
"- default setting\n\n"
"exp_heun_2_x0\n"
"- solver_type=phi_2, r=1.0, eta=0.0\n\n"
"exp_heun_2_x0_sde\n"
"- solver_type=phi_2, r=1.0, eta=1.0, s_noise=1.0"
)
)
@classmethod
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
sampler_name = "seeds_2"
sampler = comfy.samplers.ksampler(
sampler_name,
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
)
return io.NodeOutput(sampler)
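For example, the exp_heun_2_x0_sde preset described in the schema above corresponds to calling the node with (a sketch; invoking execute directly rather than through the graph):
sampler = SamplerSEEDS2.execute(solver_type="phi_2", eta=1.0, s_noise=1.0, r=1.0)  # phi_2 solver, full-size intermediate step, SDE noise on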
class Noise_EmptyNoise:
def __init__(self):
self.seed = 0
@ -996,6 +1030,7 @@ class CustomSamplersExtension(ComfyExtension):
SamplerDPMAdaptative,
SamplerER_SDE,
SamplerSASolver,
SamplerSEEDS2,
SplitSigmas,
SplitSigmasDenoise,
FlipSigmas,

View File

@ -1125,6 +1125,99 @@ class MergeTextListsNode(TextProcessingNode):
# ========== Training Dataset Nodes ==========
class ResolutionBucket(io.ComfyNode):
"""Bucket latents and conditions by resolution for efficient batch training."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ResolutionBucket",
display_name="Resolution Bucket",
category="dataset",
is_experimental=True,
is_input_list=True,
inputs=[
io.Latent.Input(
"latents",
tooltip="List of latent dicts to bucket by resolution.",
),
io.Conditioning.Input(
"conditioning",
tooltip="List of conditioning lists (must match latents length).",
),
],
outputs=[
io.Latent.Output(
display_name="latents",
is_output_list=True,
tooltip="List of batched latent dicts, one per resolution bucket.",
),
io.Conditioning.Output(
display_name="conditioning",
is_output_list=True,
tooltip="List of condition lists, one per resolution bucket.",
),
],
)
@classmethod
def execute(cls, latents, conditioning):
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
# conditioning: list[list[cond]]
# Validate lengths match
if len(latents) != len(conditioning):
raise ValueError(
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
)
# Flatten latents and conditions to individual samples
flat_latents = [] # list of (C, H, W) tensors
flat_conditions = [] # list of condition lists
for latent_dict, cond in zip(latents, conditioning):
samples = latent_dict["samples"] # (B, C, H, W)
batch_size = samples.shape[0]
# cond is a list of conditions with length == batch_size
for i in range(batch_size):
flat_latents.append(samples[i]) # (C, H, W)
flat_conditions.append(cond[i]) # single condition
# Group by resolution (H, W)
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
for latent, cond in zip(flat_latents, flat_conditions):
# each latent ends in (H, W): e.g. (C, H, W) for images or (T, C, H, W) for video latents
h, w = latent.shape[-2], latent.shape[-1]
key = (h, w)
if key not in buckets:
buckets[key] = {"latents": [], "conditions": []}
buckets[key]["latents"].append(latent)
buckets[key]["conditions"].append(cond)
# Convert buckets to output format
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
for (h, w), bucket_data in buckets.items():
# Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
output_latents.append({"samples": stacked_latents})
# Conditions stay as list of condition lists
output_conditions.append(bucket_data["conditions"])
logging.info(
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
)
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
return io.NodeOutput(output_latents, output_conditions)
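# A standalone toy sketch of the grouping above: latents are keyed by their
# trailing (H, W) and stacked per bucket.
import torch
toy_latents = [torch.zeros(4, 32, 32), torch.zeros(4, 64, 48), torch.zeros(4, 32, 32)]
toy_buckets = {}
for lat in toy_latents:
    toy_buckets.setdefault((lat.shape[-2], lat.shape[-1]), []).append(lat)
toy_batched = {k: torch.stack(v, dim=0) for k, v in toy_buckets.items()}
# toy_batched[(32, 32)].shape == (2, 4, 32, 32); toy_batched[(64, 48)].shape == (1, 4, 64, 48)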
class MakeTrainingDataset(io.ComfyNode):
"""Encode images with VAE and texts with CLIP to create a training dataset."""
@ -1373,7 +1466,7 @@ class LoadTrainingDataset(io.ComfyNode):
shard_path = os.path.join(dataset_dir, shard_file)
with open(shard_path, "rb") as f:
shard_data = torch.load(f, weights_only=True)
shard_data = torch.load(f)
all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"])
@ -1425,6 +1518,7 @@ class DatasetExtension(ComfyExtension):
MakeTrainingDataset,
SaveTrainingDataset,
LoadTrainingDataset,
ResolutionBucket,
]

View File

@ -154,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FluxKontextMultiReferenceLatentMethod",
display_name="Edit Model Reference Method",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno"],
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
),
],
outputs=[

View File

@ -243,7 +243,16 @@ class ModelPatchLoader:
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
sd = z_image_convert(sd)
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
config = {}
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
config['n_control_layers'] = 15
config['additional_in_dim'] = 17
config['refiner_control'] = True
ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None)
if ref_weight is not None:
if torch.count_nonzero(ref_weight) == 0:
config['broken'] = True
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@ -297,62 +306,122 @@ class DiffSynthCnetPatch:
return [self.model_patch]
class ZImageControlPatch:
def __init__(self, model_patch, vae, image, strength):
def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
self.model_patch = model_patch
self.vae = vae
self.image = image
self.inpaint_image = inpaint_image
self.mask = mask
self.strength = strength
self.encoded_image = self.encode_latent_cond(image)
self.encoded_image_size = (image.shape[1], image.shape[2])
self.is_inpaint = self.model_patch.model.additional_in_dim > 0
skip_encoding = False
if self.image is not None and self.inpaint_image is not None:
if self.image.shape != self.inpaint_image.shape:
skip_encoding = True
if skip_encoding:
self.encoded_image = None
else:
self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image)
if self.image is None:
self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2])
else:
self.encoded_image_size = (self.image.shape[1], self.image.shape[2])
self.temp_data = None
def encode_latent_cond(self, image):
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
return latent_image
def encode_latent_cond(self, control_image=None, inpaint_image=None):
latent_image = None
if control_image is not None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
if self.is_inpaint:
if inpaint_image is None:
inpaint_image = torch.ones_like(control_image) * 0.5
if self.mask is not None:
mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
if self.mask is None:
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
else:
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
if latent_image is None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
else:
return latent_image
def __call__(self, kwargs):
x = kwargs.get("x")
img = kwargs.get("img")
img_input = kwargs.get("img_input")
txt = kwargs.get("txt")
pe = kwargs.get("pe")
vec = kwargs.get("vec")
block_index = kwargs.get("block_index")
block_type = kwargs.get("block_type", "")
spacial_compression = self.vae.spacial_compression_encode()
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
image_scaled = None
if self.image is not None:
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2])
inpaint_scaled = None
if self.inpaint_image is not None:
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2])
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled)
comfy.model_management.load_models_gpu(loaded_models)
cnet_index = (block_index // 5)
cnet_index_float = (block_index / 5)
cnet_blocks = self.model_patch.model.n_control_layers
div = round(30 / cnet_blocks)
cnet_index = (block_index // div)
cnet_index_float = (block_index / div)
kwargs.pop("img") # we do ops in place
kwargs.pop("txt")
cnet_blocks = self.model_patch.model.n_control_layers
if cnet_index_float > (cnet_blocks - 1):
self.temp_data = None
return kwargs
if self.temp_data is None or self.temp_data[0] > cnet_index:
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
if block_type == "noise_refiner":
self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
else:
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
if block_type == "noise_refiner":
next_layer = self.temp_data[0] + 1
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
if self.temp_data[1][0] is not None:
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
else:
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
next_layer = self.temp_data[0] + 1
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
if cnet_index_float == self.temp_data[0]:
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
if cnet_blocks == self.temp_data[0] + 1:
self.temp_data = None
if cnet_index_float == self.temp_data[0]:
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
if cnet_blocks == self.temp_data[0] + 1:
self.temp_data = None
return kwargs
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype)
if self.encoded_image is not None:
self.encoded_image = self.encoded_image.to(device_or_dtype)
self.temp_data = None
return self
@ -375,9 +444,12 @@ class QwenImageDiffsynthControlnet:
CATEGORY = "advanced/loaders/qwen"
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None):
model_patched = model.clone()
image = image[:, :, :, :3]
if image is not None:
image = image[:, :, :, :3]
if inpaint_image is not None:
inpaint_image = inpaint_image[:, :, :, :3]
if mask is not None:
if mask.ndim == 3:
mask = mask.unsqueeze(1)
@ -386,11 +458,24 @@ class QwenImageDiffsynthControlnet:
mask = 1.0 - mask
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask)
model_patched.set_model_noise_refiner_patch(patch)
model_patched.set_model_double_block_patch(patch)
else:
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,)
class ZImageFunControlnet(QwenImageDiffsynthControlnet):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"vae": ("VAE",),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
},
"optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}}
CATEGORY = "advanced/loaders/zimage"
class UsoStyleProjectorPatch:
def __init__(self, model_patch, encoded_image):
@ -438,5 +523,6 @@ class USOStyleReference:
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"ZImageFunControlnet": ZImageFunControlnet,
"USOStyleReference": USOStyleReference,
}

View File

@ -221,6 +221,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
io.Image.Input("image"),
io.Combo.Input("upscale_method", options=cls.upscale_methods),
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
io.Int.Input("resolution_steps", default=1, min=1, max=256),
],
outputs=[
io.Image.Output(),
@ -228,15 +229,15 @@ class ImageScaleToTotalPixels(io.ComfyNode):
)
@classmethod
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
samples = image.movedim(-1,1)
total = int(megapixels * 1024 * 1024)
total = megapixels * 1024 * 1024
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
s = s.movedim(1,-1)
return io.NodeOutput(s)
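# A worked example of the rounding above: a 1920x1080 input scaled to 1.0
# megapixel with resolution_steps=64 snaps to 1344x768.
import math
h0, w0, megapixels, steps = 1080, 1920, 1.0, 64
scale_by = math.sqrt(megapixels * 1024 * 1024 / (w0 * h0))
width = round(w0 * scale_by / steps) * steps    # 1344
height = round(h0 * scale_by / steps) * steps   # 768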

View File

@ -2,6 +2,8 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_api.torch_helpers import set_torch_compile_wrapper
def skip_torch_compile_dict(guard_entries):
return [("transformer_options" not in entry.name) for entry in guard_entries]
class TorchCompileModel(io.ComfyNode):
@classmethod
@ -23,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
@classmethod
def execute(cls, model, backend) -> io.NodeOutput:
m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend)
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
return io.NodeOutput(m)

View File

@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont
from typing_extensions import override
import comfy.samplers
import comfy.sampler_helpers
import comfy.sd
import comfy.utils
import comfy.model_management
@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
"""
CFGGuider with modifications for training-specific logic
"""
def outer_sample(
self,
noise,
latent_image,
sampler,
sigmas,
denoise_mask=None,
callback=None,
disable_pbar=False,
seed=None,
latent_shapes=None,
):
self.inner_model, self.conds, self.loaded_models = (
comfy.sampler_helpers.prepare_sampling(
self.model_patcher,
noise.shape,
self.conds,
self.model_options,
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
)
)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(
denoise_mask, noise.shape, device
)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
comfy.samplers.cast_to_load_options(
self.model_options, device=device, dtype=self.model_patcher.model_dtype()
)
try:
self.model_patcher.pre_run()
output = self.inner_sample(
noise,
latent_image,
device,
sampler,
sigmas,
denoise_mask,
callback,
disable_pbar,
seed,
latent_shapes=latent_shapes,
)
finally:
self.model_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
del self.loaded_models
return output
def make_batch_extra_option_dict(d, indicies, full_size=None):
new_dict = {}
for k, v in d.items():
@ -65,6 +128,7 @@ class TrainSampler(comfy.samplers.Sampler):
seed=0,
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
@ -75,6 +139,28 @@ class TrainSampler(comfy.samplers.Sampler):
self.seed = seed
self.training_dtype = training_dtype
self.real_dataset: list[torch.Tensor] | None = real_dataset
# Bucket mode data
self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi)
)
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
else:
self.bucket_offsets = None
self.bucket_weights = None
self.num_images = None
def _init_bucket_data(self, bucket_latents):
"""Initialize bucket offsets and weights for sampling."""
self.bucket_offsets = [0]
bucket_sizes = []
for lat in bucket_latents:
bucket_sizes.append(lat.shape[0])
self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
self.num_images = self.bucket_offsets[-1]
# Weights for sampling buckets proportional to their size
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
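# A small numeric sketch of the precomputation above (toy sizes): buckets of
# sizes [3, 1, 2] give offsets [0, 3, 4, 6], num_images 6, and size-proportional
# sampling weights [3., 1., 2.].
import torch
toy_sizes = [3, 1, 2]
toy_offsets = [0]
for s in toy_sizes:
    toy_offsets.append(toy_offsets[-1] + s)
toy_weights = torch.tensor(toy_sizes, dtype=torch.float32)
# toy_offsets == [0, 3, 4, 6]; toy_weights == tensor([3., 1., 2.])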
def fwd_bwd(
self,
@ -115,6 +201,108 @@ class TrainSampler(comfy.samplers.Sampler):
bwd_loss.backward()
return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
"""Generate random sigma values for a batch."""
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
for _ in range(batch_size)
]
return torch.tensor(batch_sigmas).to(device)
def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
"""Execute one training step in bucket mode."""
# Sample bucket (weighted by size), then sample batch from bucket
bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi)
bucket_size = bucket_latent.shape[0]
bucket_offset = self.bucket_offsets[bucket_idx]
# Sample indices from this bucket (use all if bucket_size < batch_size)
actual_batch_size = min(self.batch_size, bucket_size)
relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
absolute_indices = [bucket_offset + idx for idx in relative_indices]
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond, # Use flattened cond with absolute indices
absolute_indices,
extra_args,
self.num_images,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
"""Execute one training step in standard (non-bucket, non-multi-res) mode."""
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
"""Execute one training step in multi-resolution mode (real_dataset is set)."""
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
total_loss = 0
for index in indicies:
single_latent = self.real_dataset[index].to(latent_image)
batch_noise = noisegen.generate_noise(
{"samples": single_latent}
).to(single_latent.device)
batch_sigmas = (
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
)
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
single_latent,
cond,
[index],
extra_args,
dataset_size,
bwd=False,
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
def sample(
self,
model_wrap,
@ -142,70 +330,18 @@ class TrainSampler(comfy.samplers.Sampler):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
self.seed + i * 1000
)
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
if self.real_dataset is None:
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
for _ in range(min(self.batch_size, dataset_size))
]
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
if self.bucket_latents is not None:
self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
elif self.real_dataset is None:
self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
else:
total_loss = 0
for index in indicies:
single_latent = self.real_dataset[index].to(latent_image)
batch_noise = noisegen.generate_noise(
{"samples": single_latent}
).to(single_latent.device)
batch_sigmas = (
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
)
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
single_latent,
cond,
[index],
extra_args,
dataset_size,
bwd=False,
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
ui_pbar.update(1)
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
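# A small sketch of the accumulation schedule above: with grad_acc=4 the
# optimizer step()/zero_grad() pair fires on iterations 3, 7, 11, ...
grad_acc = 4
step_iters = [i for i in range(12) if (i + 1) % grad_acc == 0]   # [3, 7, 11]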
@ -283,6 +419,364 @@ def unpatch(m):
del m.org_forward
def _process_latents_bucket_mode(latents):
"""Process latents for bucket mode training.
Args:
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
Returns:
list of latent tensors
"""
bucket_latents = []
for latent_dict in latents:
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
return bucket_latents
def _process_latents_standard_mode(latents):
"""Process latents for standard (non-bucket) mode training.
Args:
latents: list of latent dicts or single latent dict
Returns:
Processed latents (tensor or list of tensors)
"""
if len(latents) == 1:
return latents[0]["samples"] # Single latent dict
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
return latent_list
def _process_conditioning(positive):
"""Process conditioning - either single list or list of lists.
Args:
positive: list of conditioning
Returns:
Flattened conditioning list
"""
if len(positive) == 1:
return positive[0] # Single conditioning list
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
return flat_positive
def _prepare_latents_and_count(latents, dtype, bucket_mode):
"""Convert latents to dtype and compute image counts.
Args:
latents: Latents (tensor, list of tensors, or bucket list)
dtype: Target dtype
bucket_mode: Whether bucket mode is enabled
Returns:
tuple: (processed_latents, num_images, multi_res)
"""
if bucket_mode:
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
latents = [t.to(dtype) for t in latents]
num_buckets = len(latents)
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
return latents, num_images, multi_res
# Non-bucket mode
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
multi_res = False
else:
logging.error(f"Invalid latents type: {type(latents)}")
num_images = 0
multi_res = False
return latents, num_images, multi_res
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
"""Validate conditioning count matches image count, expand if needed.
Args:
positive: Conditioning list
num_images: Number of images
bucket_mode: Whether bucket mode is enabled
Returns:
Validated/expanded conditioning list
Raises:
ValueError: If conditioning count doesn't match image count
"""
if bucket_mode:
return positive # Skip validation in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
return positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
return positive
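# A tiny sketch of the expansion above (toy values): a single caption is
# broadcast to every image; any other count mismatch raises.
toy_positive = ["cond_a"]
toy_num_images = 3
toy_positive = toy_positive * toy_num_images   # ["cond_a", "cond_a", "cond_a"]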
def _load_existing_lora(existing_lora):
"""Load existing LoRA weights if provided.
Args:
existing_lora: LoRA filename or "[None]"
Returns:
tuple: (existing_weights dict, existing_steps int)
"""
if existing_lora == "[None]":
return {}, 0
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
existing_weights = {}
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
return existing_weights, existing_steps
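# A quick sketch of the filename convention assumed above:
# "<name>_<steps>_steps_<timestamp>" encodes the previously trained step count.
name = "trained_lora_10_steps_20250225_203716.safetensors"
steps = int(name.split("_steps_")[0].split("_")[-1])   # 10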
def _create_weight_adapter(
module, module_name, existing_weights, algorithm, lora_dtype, rank
):
"""Create a weight adapter for a module with weight.
Args:
module: The module to create adapter for
module_name: Name of the module
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (train_adapter, lora_params dict)
"""
key = f"{module_name}.weight"
shape = module.weight.shape
lora_params = {}
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
# Try to load existing adapter
existing_adapter = None
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
module_name, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
if existing_adapter is None:
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
module.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_params[f"{module_name}.{name}"] = parameter
return train_adapter.train().requires_grad_(True), lora_params
else:
# 1D weight - use BiasDiff
diff = torch.nn.Parameter(
torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
)
diff_module = BiasDiff(diff).train().requires_grad_(True)
lora_params[f"{module_name}.diff"] = diff
return diff_module, lora_params
def _create_bias_adapter(module, module_name, lora_dtype):
"""Create a bias adapter for a module with bias.
Args:
module: The module with bias
module_name: Name of the module
lora_dtype: dtype for LoRA weights
Returns:
tuple: (bias_module, lora_params dict)
"""
bias = torch.nn.Parameter(
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias).train().requires_grad_(True)
lora_params = {f"{module_name}.diff_b": bias}
return bias_module, lora_params
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
"""Setup all LoRA adapters on the model.
Args:
mp: Model patcher
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (lora_sd dict, all_weight_adapters list)
"""
lora_sd = {}
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
adapter, params = _create_weight_adapter(
m, n, existing_weights, algorithm, lora_dtype, rank
)
lora_sd.update(params)
key = f"{n}.weight"
mp.add_weight_wrapper(key, adapter)
all_weight_adapters.append(adapter)
if hasattr(m, "bias") and m.bias is not None:
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
lora_sd.update(bias_params)
key = f"{n}.bias"
mp.add_weight_wrapper(key, bias_adapter)
all_weight_adapters.append(bias_adapter)
return lora_sd, all_weight_adapters
def _create_optimizer(optimizer_name, parameters, learning_rate):
"""Create optimizer based on name.
Args:
optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
parameters: Parameters to optimize
learning_rate: Learning rate
Returns:
Optimizer instance
"""
if optimizer_name == "Adam":
return torch.optim.Adam(parameters, lr=learning_rate)
elif optimizer_name == "AdamW":
return torch.optim.AdamW(parameters, lr=learning_rate)
elif optimizer_name == "SGD":
return torch.optim.SGD(parameters, lr=learning_rate)
elif optimizer_name == "RMSprop":
return torch.optim.RMSprop(parameters, lr=learning_rate)
def _create_loss_function(loss_function_name):
"""Create loss function based on name.
Args:
loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
Returns:
Loss function instance
"""
if loss_function_name == "MSE":
return torch.nn.MSELoss()
elif loss_function_name == "L1":
return torch.nn.L1Loss()
elif loss_function_name == "Huber":
return torch.nn.HuberLoss()
elif loss_function_name == "SmoothL1":
return torch.nn.SmoothL1Loss()
def _run_training_loop(
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
):
"""Execute the training loop.
Args:
guider: The guider object
train_sampler: The training sampler
latents: Latent tensors
num_images: Number of images
seed: Random seed
bucket_mode: Whether bucket mode is enabled
multi_res: Whether multi-resolution mode is enabled
"""
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if bucket_mode:
# Use first bucket's first latent as dummy for guider
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": dummy_latent}),
dummy_latent,
train_sampler,
sigmas,
seed=noise.seed,
)
elif multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
else:
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
class TrainLoraNode(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -385,6 +879,11 @@ class TrainLoraNode(io.ComfyNode):
default="[None]",
tooltip="The existing LoRA to append to. Set to None for new LoRA.",
),
io.Boolean.Input(
"bucket_mode",
default=False,
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
),
],
outputs=[
io.Model.Output(
@ -419,6 +918,7 @@ class TrainLoraNode(io.ComfyNode):
algorithm,
gradient_checkpointing,
existing_lora,
bucket_mode,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
@ -427,215 +927,125 @@ class TrainLoraNode(io.ComfyNode):
grad_accumulation_steps = grad_accumulation_steps[0]
learning_rate = learning_rate[0]
rank = rank[0]
optimizer = optimizer[0]
loss_function = loss_function[0]
optimizer_name = optimizer[0]
loss_function_name = loss_function[0]
seed = seed[0]
training_dtype = training_dtype[0]
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
# Handle latents - either single dict or list of dicts
if len(latents) == 1:
latents = latents[0]["samples"] # Single latent dict
# Process latents based on mode
if bucket_mode:
latents = _process_latents_bucket_mode(latents)
else:
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
latents = latent_list
latents = _process_latents_standard_mode(latents)
# Handle conditioning - either single list or list of lists
if len(positive) == 1:
positive = positive[0] # Single conditioning list
else:
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
positive = flat_positive
# Process conditioning
positive = _process_conditioning(positive)
# Setup model and dtype
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
# latents here can be list of different size latent or one large batch
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
else:
logging.error(f"Invalid latents type: {type(latents)}")
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
)
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Setup models for training
mp.model.requires_grad_(False)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
existing_weights, existing_steps = _load_existing_lora(existing_lora)
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
n, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
else:
existing_adapter = None
adapter_cls = adapter_maps[algorithm]
# Setup LoRA adapters
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(
lora_dtype
)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
m.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_sd[f"{n}.{name}"] = parameter
# Create optimizer and loss function
optimizer = _create_optimizer(
optimizer_name, lora_sd.values(), learning_rate
)
criterion = _create_loss_function(loss_function_name)
mp.add_weight_wrapper(key, train_adapter)
all_weight_adapters.append(train_adapter)
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=lora_dtype, requires_grad=True
)
)
diff_module = BiasDiff(diff)
mp.add_weight_wrapper(key, BiasDiff(diff))
all_weight_adapters.append(diff_module)
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(
m.bias.shape, dtype=lora_dtype, requires_grad=True
)
)
bias_module = BiasDiff(bias)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
all_weight_adapters.append(bias_module)
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
# setup models
# Setup gradient checkpointing
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(
mp.model.diffusion_model
):
patch(m)
mp.model.requires_grad_(False)
torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=True
)
torch.cuda.empty_cache()
# Setup sampler and guider like in test script
# Setup loss tracking
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
# Create sampler
if bucket_mode:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
)
else:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
# Training loop
# Setup guider
guider = TrainGuider(mp)
guider.set_conds(positive)
# Run training loop
try:
# Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
_run_training_loop(
guider,
train_sampler,
sigmas,
seed=noise.seed,
latents,
num_images,
seed,
bucket_mode,
multi_res,
)
finally:
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
# Finalize adapters
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
@ -645,7 +1055,7 @@ class TrainLoraNode(io.ComfyNode):
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):
class LoraModelLoader(io.ComfyNode):#
@classmethod
def define_schema(cls):
return io.Schema(

View File

@ -8,10 +8,7 @@ import json
from typing import Optional
from typing_extensions import override
from fractions import Fraction
from comfy_api.input import AudioInput, ImageInput, VideoInput
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
from comfy_api.latest import ComfyExtension, io, ui
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
from comfy.cli_args import args
class SaveWEBM(io.ComfyNode):
@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode):
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode):
inputs=[
io.Video.Input("video", tooltip="The video to save."),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode):
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
format=VideoContainer(format),
format=Types.VideoContainer(format),
codec=codec,
metadata=saved_metadata
)
@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode):
)
@classmethod
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
return io.NodeOutput(
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
)
class GetVideoComponents(io.ComfyNode):
@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode):
)
@classmethod
def execute(cls, video: VideoInput) -> io.NodeOutput:
def execute(cls, video: Input.Video) -> io.NodeOutput:
components = video.get_components()
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class LoadVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode):
@classmethod
def execute(cls, file) -> io.NodeOutput:
video_path = folder_paths.get_annotated_filepath(file)
return io.NodeOutput(VideoFromFile(video_path))
return io.NodeOutput(InputImpl.VideoFromFile(video_path))
@classmethod
def fingerprint_inputs(s, file):

View File

@ -0,0 +1,535 @@
import nodes
import node_helpers
import torch
import torchvision.transforms.functional as TF
import comfy.model_management
import comfy.utils
import numpy as np
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_wan import parse_json_tracks
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
from PIL import Image, ImageDraw
SKIP_ZERO = False
def get_pos_emb(
pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
pos_emb_dim: int,
theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), # Function to compute thetas based on position and embedding dimensions.
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim)
assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."
pos_k = pos_k.to(device, dtype)
if SKIP_ZERO:
pos_k = pos_k + 1
batch_size = pos_k.size(0)
denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype)
# Expand denominator to match the shape needed for broadcasting
denominator_expanded = denominator.view(1, -1).expand(batch_size, -1)
thetas = theta_func(denominator_expanded, pos_emb_dim)
# Ensure pos_k is in the correct shape for broadcasting
pos_k_expanded = pos_k.view(-1, 1).to(dtype)
sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas))
cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas))
# Concatenate sine and cosine embeddings along the last dimension
pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1)
return pos_emb
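# A minimal usage sketch, relying on get_pos_emb defined above: four positions
# with an 8-dimensional embedding yield a (4, 8) tensor of sin/cos features.
example_emb = get_pos_emb(torch.tensor([0.0, 1.0, 2.0, 3.0]), pos_emb_dim=8)
assert example_emb.shape == (4, 8)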
def create_pos_embeddings(
pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
pred_visibility: torch.Tensor, # the predicted visibility [T, N]
downsample_ratios: list[int], # the ratios for downsampling time, height, and width
height: int, # the height of the feature map
width: int, # the width of the feature map
track_num: int = -1, # the number of tracks to use
t_down_strategy: str = "sample", # the strategy for downsampling time dimension
):
assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."
t, n, _ = pred_tracks.shape
t_down, h_down, w_down = downsample_ratios
track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long)
if track_num == -1:
track_num = n
tracks_idx = torch.randperm(n)[:track_num]
tracks = pred_tracks[:, tracks_idx]
visibility = pred_visibility[:, tracks_idx]
for t_idx in range(0, t, t_down):
if t_down_strategy == "sample" or t_idx == 0:
cur_tracks = tracks[t_idx] # [N, 2]
cur_visibility = visibility[t_idx] # [N]
else:
cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0)
cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0)
for i in range(track_num):
if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height:
continue
x, y = cur_tracks[i]
x, y = int(x // w_down), int(y // h_down)
track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x
return track_pos # the position embeddings, [N, T', 2], 2 = height, width
def replace_feature(
vae_feature: torch.Tensor, # [B, C', T', H', W']
track_pos: torch.Tensor, # [B, N, T', 2]
strength: float = 1.0
) -> torch.Tensor:
b, _, t, h, w = vae_feature.shape
assert b == track_pos.shape[0], "Batch size mismatch."
n = track_pos.shape[1]
# Shuffle the trajectory order
track_pos = track_pos[:, torch.randperm(n), :, :]
# Extract coordinates at time steps ≥ 1 and generate a valid mask
current_pos = track_pos[:, :, 1:, :] # [B, N, T-1, 2]
mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) # [B, N, T-1]
# Get all valid indices
valid_indices = mask.nonzero(as_tuple=False) # [num_valid, 3]
num_valid = valid_indices.shape[0]
if num_valid == 0:
return vae_feature
# Decompose valid indices into each dimension
batch_idx = valid_indices[:, 0]
track_idx = valid_indices[:, 1]
t_rel = valid_indices[:, 2]
t_target = t_rel + 1 # Convert to original time step indices
# Extract target position coordinates
h_target = current_pos[batch_idx, track_idx, t_rel, 0].long() # Ensure integer indices
w_target = current_pos[batch_idx, track_idx, t_rel, 1].long()
# Extract source position coordinates (t=0)
h_source = track_pos[batch_idx, track_idx, 0, 0].long()
w_source = track_pos[batch_idx, track_idx, 0, 1].long()
# Get source features and assign to target positions
src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
return vae_feature
# Visualize functions
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
draw = ImageDraw.Draw(overlay, 'RGBA')
points = points[::-1]
# Compute total length
total_length = 0
segment_lengths = []
for i in range(len(points) - 1):
dx = points[i + 1][0] - points[i][0]
dy = points[i + 1][1] - points[i][1]
length = (dx * dx + dy * dy) ** 0.5
segment_lengths.append(length)
total_length += length
if total_length == 0:
return
accumulated_length = 0
# Draw the gradient polyline
for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])):
segment_length = segment_lengths[idx]
steps = max(int(segment_length), 1)
for i in range(steps):
current_length = accumulated_length + (i / steps) * segment_length
ratio = current_length / total_length
alpha = int(255 * (1 - ratio) * opacity)
color = (*start_color, alpha)
x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps)
y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps)
dynamic_line_width = max(int(line_width * (1 - ratio)), 1)
draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width)
accumulated_length += segment_length
def add_weighted(rgb, track):
rgb = np.array(rgb) # [H, W, C] "RGB"
track = np.array(track) # [H, W, C] "RGBA"
alpha = track[:, :, 3] / 255.0
alpha = np.stack([alpha] * 3, axis=-1)
blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha)
return Image.fromarray(blend_img.astype(np.uint8))
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
video = video.byte().cpu().numpy() # (81, 480, 832, 3)
tracks = tracks[0].long().detach().cpu().numpy()
if visibility is not None:
visibility = visibility[0].detach().cpu().numpy()
num_frames, height, width = video.shape[:3]
num_tracks = tracks.shape[1]
alpha_opacity = int(255 * opacity)
output_frames = []
for t in range(num_frames):
frame_rgb = video[t].astype(np.float32)
# Create a single RGBA overlay for all tracks in this frame
overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
draw_overlay = ImageDraw.Draw(overlay)
polyline_data = []
# Draw all circles on a single overlay
for n in range(num_tracks):
if visibility is not None and visibility[t, n] == 0:
continue
track_coord = tracks[t, n]
color = color_map[n % len(color_map)]
circle_color = color + (alpha_opacity,)
draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size),
fill=circle_color
)
# Store polyline data for batch processing
tracks_coord = tracks[max(t - track_frame, 0):t + 1, n]
if len(tracks_coord) > 1:
polyline_data.append((tracks_coord, color))
# Blend circles overlay once
overlay_np = np.array(overlay)
alpha = overlay_np[:, :, 3:4] / 255.0
frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
# Draw all polylines on a single overlay
if polyline_data:
polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
for tracks_coord, color in polyline_data:
_draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity)
# Blend polylines overlay once
polyline_np = np.array(polyline_overlay)
alpha = polyline_np[:, :, 3:4] / 255.0
frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8)))
return output_frames
class WanMoveVisualizeTracks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveVisualizeTracks",
category="conditioning/video_models",
inputs=[
io.Image.Input("images"),
io.Tracks.Input("tracks", optional=True),
io.Int.Input("line_resolution", default=24, min=1, max=1024),
io.Int.Input("circle_size", default=12, min=1, max=128),
io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
io.Int.Input("line_width", default=16, min=1, max=128),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput:
if tracks is None:
return io.NodeOutput(images)
track_path = tracks["track_path"].unsqueeze(0)
track_visibility = tracks["track_visibility"].unsqueeze(0)
images_in = images * 255.0
if images_in.shape[0] != track_path.shape[1]:
repeat_count = track_path.shape[1] // images.shape[0]
images_in = images_in.repeat(repeat_count, 1, 1, 1)
track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float()
return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device()))
class WanMoveTracksFromCoords(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveTracksFromCoords",
category="conditioning/video_models",
inputs=[
io.String.Input("track_coords", force_input=True, default="[]", optional=True),
io.Mask.Input("track_mask", optional=True),
],
outputs=[
io.Tracks.Output(),
io.Int.Output(display_name="track_length"),
],
)
@classmethod
def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
device=comfy.model_management.intermediate_device()
tracks_data = parse_json_tracks(track_coords)
track_length = len(tracks_data[0])
track_list = [
[[track[frame]['x'], track[frame]['y']] for track in tracks_data]
for frame in range(len(tracks_data[0]))
]
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
num_tracks = tracks.shape[-2]
if track_mask is None:
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
else:
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
out_track_info = {}
out_track_info["track_path"] = tracks
out_track_info["track_visibility"] = track_visibility
return io.NodeOutput(out_track_info, track_length)
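# A minimal sketch of the track_coords format as inferred from the parsing
# above: a JSON list of tracks, each a per-frame list of {"x", "y"} pixel
# coordinates, with every track covering the same number of frames.
example_track_coords = '[[{"x": 10, "y": 20}, {"x": 12, "y": 22}], [{"x": 100, "y": 50}, {"x": 98, "y": 52}]]'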
class GenerateTracks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GenerateTracks",
category="conditioning/video_models",
inputs=[
io.Int.Input("width", default=832, min=16, max=4096, step=16),
io.Int.Input("height", default=480, min=16, max=4096, step=16),
io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
io.Int.Input("num_frames", default=81, min=1, max=1024),
io.Int.Input("num_tracks", default=5, min=1, max=100),
io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
io.Combo.Input(
"interpolation",
options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
tooltip="Controls the timing/speed of movement along the path.",
),
io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
],
outputs=[
io.Tracks.Output(),
io.Int.Output(display_name="track_length"),
],
)
@classmethod
def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
device = comfy.model_management.intermediate_device()
track_length = num_frames
# normalized coordinates to pixel coordinates
start_x_px = start_x * width
start_y_px = start_y * height
mid_x_px = mid_x * width
mid_y_px = mid_y * height
end_x_px = end_x * width
end_y_px = end_y * height
track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional
t = torch.linspace(0, 1, num_frames, device=device)
if interpolation == "constant": # All points stay at start position
interp_values = torch.zeros_like(t)
elif interpolation == "linear":
interp_values = t
elif interpolation == "ease_in":
interp_values = t ** 2
elif interpolation == "ease_out":
interp_values = 1 - (1 - t) ** 2
elif interpolation == "ease_in_out":
interp_values = t * t * (3 - 2 * t)
if bezier: # apply interpolation to t for timing control along the bezier path
t_interp = interp_values
one_minus_t = 1 - t_interp
x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
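# The tangent computed below is the derivative of the quadratic Bezier curve,
# B'(t) = 2*(1 - t)*(P1 - P0) + 2*t*(P2 - P1), with P0 = start, P1 = mid control point,
# P2 = end; its perpendicular (computed per frame further down) is the direction along
# which the parallel tracks are spread.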
tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
else: # calculate base x and y positions for each frame (center track)
x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
# For non-bezier, tangent is constant (direction from start to end)
tangent_x = torch.full_like(t, end_x_px - start_x_px)
tangent_y = torch.full_like(t, end_y_px - start_y_px)
track_list = []
for frame_idx in range(num_frames):
# Calculate perpendicular direction at this frame
tx = tangent_x[frame_idx].item()
ty = tangent_y[frame_idx].item()
length = (tx ** 2 + ty ** 2) ** 0.5
if length > 0: # Perpendicular unit vector (rotate 90 degrees)
perp_x = -ty / length
perp_y = tx / length
else: # If tangent is zero, spread horizontally
perp_x = 1.0
perp_y = 0.0
frame_tracks = []
for track_idx in range(num_tracks): # center tracks around the main path; offsets range from -(num_tracks-1)/2 to +(num_tracks-1)/2
offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
track_x = x_positions[frame_idx].item() + perp_x * offset
track_y = y_positions[frame_idx].item() + perp_y * offset
frame_tracks.append([track_x, track_y])
track_list.append(frame_tracks)
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
if track_mask is None:
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
else:
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
out_track_info = {}
out_track_info["track_path"] = tracks
out_track_info["track_visibility"] = track_visibility
return io.NodeOutput(out_track_info, track_length)
class WanMoveConcatTrack(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveConcatTrack",
category="conditioning/video_models",
inputs=[
io.Tracks.Input("tracks_1"),
io.Tracks.Input("tracks_2", optional=True),
],
outputs=[
io.Tracks.Output(),
],
)
@classmethod
def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput:
if tracks_2 is None:
return io.NodeOutput(tracks_1)
tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1) # Concatenate along the track dimension
mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1)
out_track_info = {}
out_track_info["track_path"] = tracks_out
out_track_info["track_visibility"] = mask_out
return io.NodeOutput(out_track_info)
class WanMoveTrackToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveTrackToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Tracks.Input("tracks", optional=True),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image"),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
device = comfy.model_management.intermediate_device()
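# Empty latent in the layout expected by this model: 16 latent channels,
# (length - 1) // 4 + 1 latent frames and height // 8 x width // 8 spatial size,
# i.e. roughly 4x temporal and 8x spatial compression (assumed from the divisors
# used here and from the [4, 8, 8] strides passed to create_pos_embeddings below).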
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
if tracks is not None and strength > 0.0:
tracks_path = tracks["track_path"][:length] # [T, N, 2]
num_tracks = tracks_path.shape[-2]
track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device))
track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)
else:
concat_latent_image_pos = concat_latent_image
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanMoveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanMoveTrackToVideo,
WanMoveTracksFromCoords,
WanMoveConcatTrack,
WanMoveVisualizeTracks,
GenerateTracks,
]
async def comfy_entrypoint() -> WanMoveExtension:
return WanMoveExtension()
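# Illustrative wiring of these nodes in a workflow (not a required layout): tracks produced by
# GenerateTracks or WanMoveTracksFromCoords can be merged with WanMoveConcatTrack, overlaid on
# frames with WanMoveVisualizeTracks for inspection, and fed to WanMoveTrackToVideo together with
# positive/negative conditioning and a VAE to produce the concat latent and mask used for sampling.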

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.76"
__version__ = "0.5.1"

View File

@ -13,6 +13,7 @@ import asyncio
import torch
import comfy.model_management
from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
BasicCache,
@ -669,6 +670,8 @@ class PromptExecutor:
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
set_preview_method(extra_data.get("preview_method"))
nodes.interrupt_processing(False)
if "client_id" in extra_data:

View File

@ -8,6 +8,8 @@ import folder_paths
import comfy.utils
import logging
default_preview_method = args.preview_method
MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
@ -125,3 +127,11 @@ def prepare_callback(model, steps, x0_output_dict=None):
pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback
def set_preview_method(override: str = None):
if override and override != "default":
method = LatentPreviewMethod.from_string(override)
if method is not None:
args.preview_method = method
return
args.preview_method = default_preview_method
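# e.g. set_preview_method("taesd") switches previews to TAESD for the next queued execution,
# while set_preview_method(None) or set_preview_method("default") restores default_preview_method;
# the unit tests added later in this diff exercise exactly these calls.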

main.py
View File

@ -23,6 +23,38 @@ if __name__ == "__main__":
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
if os.name == "nt":
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
if __name__ == "__main__":
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
if args.default_device is not None:
default_dev = args.default_device
devices = list(range(32))
devices.remove(default_dev)
devices.insert(0, default_dev)
devices = ','.join(map(str, devices))
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.oneapi_device_selector is not None:
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
import cuda_malloc
if "rocm" in cuda_malloc.get_torch_version_noimport():
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
def handle_comfyui_manager_unavailable():
if not args.windows_standalone_build:
@ -137,40 +169,6 @@ import shutil
import threading
import gc
if os.name == "nt":
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
if __name__ == "__main__":
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
if args.default_device is not None:
default_dev = args.default_device
devices = list(range(32))
devices.remove(default_dev)
devices.insert(0, default_dev)
devices = ','.join(map(str, devices))
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.oneapi_device_selector is not None:
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
import cuda_malloc
if "rocm" in cuda_malloc.get_torch_version_noimport():
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")

View File

@ -1 +1 @@
comfyui_manager==4.0.3b4
comfyui_manager==4.0.3b5

View File

@ -2358,6 +2358,7 @@ async def init_builtin_extra_nodes():
"nodes_logic.py",
"nodes_nop.py",
"nodes_kandinsky5.py",
"nodes_wanmove.py",
]
import_failed = []
@ -2383,7 +2384,6 @@ async def init_builtin_api_nodes():
"nodes_recraft.py",
"nodes_pixverse.py",
"nodes_stability.py",
"nodes_pika.py",
"nodes_runway.py",
"nodes_sora.py",
"nodes_topaz.py",

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.76"
version = "0.5.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.33.10
comfyui-workflow-templates==0.7.25
comfyui-frontend-package==1.34.9
comfyui-workflow-templates==0.7.59
comfyui-embedded-docs==0.3.1
torch
torchsde

server.py
View File

@ -7,6 +7,7 @@ import time
import nodes
import folder_paths
import execution
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
import uuid
import urllib
import json
@ -47,6 +48,12 @@ from middleware.cache_middleware import cache_control
if args.enable_manager:
import comfyui_manager
def _remove_sensitive_from_queue(queue: list) -> list:
"""Remove sensitive data (index 5) from queue item tuples."""
return [item[:5] for item in queue]
async def send_socket_catch_exception(function, message):
try:
await function(message)
@ -694,6 +701,129 @@ class PromptServer():
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/api/jobs")
async def get_jobs(request):
"""List all jobs with filtering, sorting, and pagination.
Query parameters:
status: Filter by status (comma-separated): pending, in_progress, completed, failed
workflow_id: Filter by workflow ID
sort_by: Sort field: created_at (default), execution_duration
sort_order: Sort direction: asc, desc (default)
limit: Max items to return (positive integer)
offset: Items to skip (non-negative integer, default 0)
"""
query = request.rel_url.query
status_param = query.get('status')
workflow_id = query.get('workflow_id')
sort_by = query.get('sort_by', 'created_at').lower()
sort_order = query.get('sort_order', 'desc').lower()
status_filter = None
if status_param:
status_filter = [s.strip().lower() for s in status_param.split(',') if s.strip()]
invalid_statuses = [s for s in status_filter if s not in JobStatus.ALL]
if invalid_statuses:
return web.json_response(
{"error": f"Invalid status value(s): {', '.join(invalid_statuses)}. Valid values: {', '.join(JobStatus.ALL)}"},
status=400
)
if sort_by not in {'created_at', 'execution_duration'}:
return web.json_response(
{"error": "sort_by must be 'created_at' or 'execution_duration'"},
status=400
)
if sort_order not in {'asc', 'desc'}:
return web.json_response(
{"error": "sort_order must be 'asc' or 'desc'"},
status=400
)
limit = None
# If a limit is provided, validate that it is a positive integer; otherwise no limit is applied
if 'limit' in query:
try:
limit = int(query.get('limit'))
if limit <= 0:
return web.json_response(
{"error": "limit must be a positive integer"},
status=400
)
except (ValueError, TypeError):
return web.json_response(
{"error": "limit must be an integer"},
status=400
)
offset = 0
if 'offset' in query:
try:
offset = int(query.get('offset'))
if offset < 0:
offset = 0
except (ValueError, TypeError):
return web.json_response(
{"error": "offset must be an integer"},
status=400
)
running, queued = self.prompt_queue.get_current_queue_volatile()
history = self.prompt_queue.get_history()
running = _remove_sensitive_from_queue(running)
queued = _remove_sensitive_from_queue(queued)
jobs, total = get_all_jobs(
running, queued, history,
status_filter=status_filter,
workflow_id=workflow_id,
sort_by=sort_by,
sort_order=sort_order,
limit=limit,
offset=offset
)
has_more = (offset + len(jobs)) < total
return web.json_response({
'jobs': jobs,
'pagination': {
'offset': offset,
'limit': limit,
'total': total,
'has_more': has_more
}
})
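# Example usage of this endpoint (host/port and the counts below are illustrative only):
#   GET http://127.0.0.1:8188/api/jobs?status=completed,failed&sort_by=created_at&sort_order=desc&limit=2&offset=0
# which returns a body shaped like the json_response built above, e.g.:
#   {
#     "jobs": [ ...up to 2 job dicts... ],
#     "pagination": {"offset": 0, "limit": 2, "total": 17, "has_more": true}
#   }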
@routes.get("/api/jobs/{job_id}")
async def get_job_by_id(request):
"""Get a single job by ID."""
job_id = request.match_info.get("job_id", None)
if not job_id:
return web.json_response(
{"error": "job_id is required"},
status=400
)
running, queued = self.prompt_queue.get_current_queue_volatile()
history = self.prompt_queue.get_history(prompt_id=job_id)
running = _remove_sensitive_from_queue(running)
queued = _remove_sensitive_from_queue(queued)
job = get_job(job_id, running, queued, history)
if job is None:
return web.json_response(
{"error": "Job not found"},
status=404
)
return web.json_response(job)
@routes.get("/history")
async def get_history(request):
max_items = request.rel_url.query.get("max_items", None)
@ -717,9 +847,8 @@ class PromptServer():
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue_volatile()
remove_sensitive = lambda queue: [x[:5] for x in queue]
queue_info['queue_running'] = remove_sensitive(current_queue[0])
queue_info['queue_pending'] = remove_sensitive(current_queue[1])
queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0])
queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
return web.json_response(queue_info)
@routes.post("/prompt")

View File

@ -0,0 +1,352 @@
"""
Unit tests for Queue-specific Preview Method Override feature.
Tests the preview method override functionality:
- LatentPreviewMethod.from_string() method
- set_preview_method() function in latent_preview.py
- default_preview_method variable
- Integration with args.preview_method
"""
import pytest
from comfy.cli_args import args, LatentPreviewMethod
from latent_preview import set_preview_method, default_preview_method
class TestLatentPreviewMethodFromString:
"""Test LatentPreviewMethod.from_string() classmethod."""
@pytest.mark.parametrize("value,expected", [
("auto", LatentPreviewMethod.Auto),
("latent2rgb", LatentPreviewMethod.Latent2RGB),
("taesd", LatentPreviewMethod.TAESD),
("none", LatentPreviewMethod.NoPreviews),
])
def test_valid_values_return_enum(self, value, expected):
"""Valid string values should return corresponding enum."""
assert LatentPreviewMethod.from_string(value) == expected
@pytest.mark.parametrize("invalid", [
"invalid",
"TAESD", # Case sensitive
"AUTO", # Case sensitive
"Latent2RGB", # Case sensitive
"latent",
"",
"default", # default is special, not a method
])
def test_invalid_values_return_none(self, invalid):
"""Invalid string values should return None."""
assert LatentPreviewMethod.from_string(invalid) is None
class TestLatentPreviewMethodEnumValues:
"""Test LatentPreviewMethod enum has expected values."""
def test_enum_values(self):
"""Verify enum values match expected strings."""
assert LatentPreviewMethod.NoPreviews.value == "none"
assert LatentPreviewMethod.Auto.value == "auto"
assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb"
assert LatentPreviewMethod.TAESD.value == "taesd"
def test_enum_count(self):
"""Verify exactly 4 preview methods exist."""
assert len(LatentPreviewMethod) == 4
class TestSetPreviewMethod:
"""Test set_preview_method() function from latent_preview.py."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_override_with_taesd(self):
"""'taesd' should set args.preview_method to TAESD."""
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
def test_override_with_latent2rgb(self):
"""'latent2rgb' should set args.preview_method to Latent2RGB."""
set_preview_method("latent2rgb")
assert args.preview_method == LatentPreviewMethod.Latent2RGB
def test_override_with_auto(self):
"""'auto' should set args.preview_method to Auto."""
set_preview_method("auto")
assert args.preview_method == LatentPreviewMethod.Auto
def test_override_with_none_value(self):
"""'none' should set args.preview_method to NoPreviews."""
set_preview_method("none")
assert args.preview_method == LatentPreviewMethod.NoPreviews
def test_default_restores_original(self):
"""'default' should restore to default_preview_method."""
# First override to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then use 'default' to restore
set_preview_method("default")
assert args.preview_method == default_preview_method
def test_none_param_restores_original(self):
"""None parameter should restore to default_preview_method."""
# First override to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then use None to restore
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_empty_string_restores_original(self):
"""Empty string should restore to default_preview_method."""
set_preview_method("taesd")
set_preview_method("")
assert args.preview_method == default_preview_method
def test_invalid_value_restores_original(self):
"""Invalid value should restore to default_preview_method."""
set_preview_method("taesd")
set_preview_method("invalid_method")
assert args.preview_method == default_preview_method
def test_case_sensitive_invalid_restores(self):
"""Case-mismatched values should restore to default."""
set_preview_method("taesd")
set_preview_method("TAESD") # Wrong case
assert args.preview_method == default_preview_method
class TestDefaultPreviewMethod:
"""Test default_preview_method module variable."""
def test_default_is_not_none(self):
"""default_preview_method should not be None."""
assert default_preview_method is not None
def test_default_is_enum_member(self):
"""default_preview_method should be a LatentPreviewMethod enum."""
assert isinstance(default_preview_method, LatentPreviewMethod)
def test_default_matches_args_initial(self):
"""default_preview_method should match CLI default or user setting."""
# This tests that default_preview_method was captured at module load
# After set_preview_method(None), args should equal default
original = args.preview_method
set_preview_method("taesd")
set_preview_method(None)
assert args.preview_method == default_preview_method
args.preview_method = original
class TestArgsPreviewMethodModification:
"""Test args.preview_method can be modified correctly."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_args_accepts_all_enum_values(self):
"""args.preview_method should accept all LatentPreviewMethod values."""
for method in LatentPreviewMethod:
args.preview_method = method
assert args.preview_method == method
def test_args_modification_and_restoration(self):
"""args.preview_method should be modifiable and restorable."""
original = args.preview_method
args.preview_method = LatentPreviewMethod.TAESD
assert args.preview_method == LatentPreviewMethod.TAESD
args.preview_method = original
assert args.preview_method == original
class TestExecutionFlow:
"""Test the execution flow pattern used in execution.py."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_sequential_executions_with_different_methods(self):
"""Simulate multiple queue executions with different preview methods."""
# Execution 1: taesd
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Execution 2: none
set_preview_method("none")
assert args.preview_method == LatentPreviewMethod.NoPreviews
# Execution 3: default (restore)
set_preview_method("default")
assert args.preview_method == default_preview_method
# Execution 4: auto
set_preview_method("auto")
assert args.preview_method == LatentPreviewMethod.Auto
# Execution 5: no override (None)
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_override_then_default_pattern(self):
"""Test the pattern: override -> execute -> next call restores."""
# First execution with override
set_preview_method("latent2rgb")
assert args.preview_method == LatentPreviewMethod.Latent2RGB
# Second execution without override restores default
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_extra_data_simulation(self):
"""Simulate extra_data.get('preview_method') patterns."""
# Simulate: extra_data = {"preview_method": "taesd"}
extra_data = {"preview_method": "taesd"}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
# Simulate: extra_data = {}
extra_data = {}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
# Simulate: extra_data = {"preview_method": "default"}
extra_data = {"preview_method": "default"}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
class TestRealWorldScenarios:
"""Tests using real-world prompt data patterns."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_captured_prompt_without_preview_method(self):
"""
Test with captured prompt that has no preview_method.
Based on: tests-unit/execution_test/fixtures/default_prompt.json
"""
# Real captured extra_data structure (preview_method absent)
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
"create_time": 1765416558179
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
def test_captured_prompt_with_preview_method_taesd(self):
"""Test captured prompt with preview_method: taesd."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
"preview_method": "taesd"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
def test_captured_prompt_with_preview_method_none(self):
"""Test captured prompt with preview_method: none (disable preview)."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "none"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.NoPreviews
def test_captured_prompt_with_preview_method_latent2rgb(self):
"""Test captured prompt with preview_method: latent2rgb."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "latent2rgb"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Latent2RGB
def test_captured_prompt_with_preview_method_auto(self):
"""Test captured prompt with preview_method: auto."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "auto"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Auto
def test_captured_prompt_with_preview_method_default(self):
"""Test captured prompt with preview_method: default (use CLI setting)."""
# First set to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then simulate a prompt with "default"
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "default"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
def test_sequential_queue_with_different_preview_methods(self):
"""
Simulate real queue scenario: multiple prompts with different settings.
This tests the actual usage pattern in ComfyUI.
"""
# Queue 1: User wants TAESD preview
extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"}
set_preview_method(extra_data_1.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
# Queue 2: User wants no preview (faster execution)
extra_data_2 = {"client_id": "client-2", "preview_method": "none"}
set_preview_method(extra_data_2.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.NoPreviews
# Queue 3: User doesn't specify (use server default)
extra_data_3 = {"client_id": "client-3"}
set_preview_method(extra_data_3.get("preview_method"))
assert args.preview_method == default_preview_method
# Queue 4: User explicitly wants default
extra_data_4 = {"client_id": "client-4", "preview_method": "default"}
set_preview_method(extra_data_4.get("preview_method"))
assert args.preview_method == default_preview_method
# Queue 5: User wants latent2rgb
extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"}
set_preview_method(extra_data_5.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Latent2RGB

View File

@ -99,6 +99,37 @@ class ComfyClient:
with urllib.request.urlopen(url) as response:
return json.loads(response.read())
def get_jobs(self, status=None, limit=None, offset=None, sort_by=None, sort_order=None):
url = "http://{}/api/jobs".format(self.server_address)
params = {}
if status is not None:
params["status"] = status
if limit is not None:
params["limit"] = limit
if offset is not None:
params["offset"] = offset
if sort_by is not None:
params["sort_by"] = sort_by
if sort_order is not None:
params["sort_order"] = sort_order
if params:
url_values = urllib.parse.urlencode(params)
url = "{}?{}".format(url, url_values)
with urllib.request.urlopen(url) as response:
return json.loads(response.read())
def get_job(self, job_id):
url = "http://{}/api/jobs/{}".format(self.server_address, job_id)
try:
with urllib.request.urlopen(url) as response:
return json.loads(response.read())
except urllib.error.HTTPError as e:
if e.code == 404:
return None
raise
def set_test_name(self, name):
self.test_name = name
@ -877,3 +908,106 @@ class TestExecution:
result = client.get_all_history(max_items=5, offset=len(all_history) - 1)
assert len(result) <= 1, "Should return at most 1 item when offset is near end"
# Jobs API tests
def test_jobs_api_job_structure(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test that job objects have required fields"""
self._create_history_item(client, builder)
jobs_response = client.get_jobs(status="completed", limit=1)
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
job = jobs_response["jobs"][0]
assert "id" in job, "Job should have id"
assert "status" in job, "Job should have status"
assert "create_time" in job, "Job should have create_time"
assert "outputs_count" in job, "Job should have outputs_count"
assert "preview_output" in job, "Job should have preview_output"
def test_jobs_api_preview_output_structure(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test that preview_output has correct structure"""
self._create_history_item(client, builder)
jobs_response = client.get_jobs(status="completed", limit=1)
job = jobs_response["jobs"][0]
if job["preview_output"] is not None:
preview = job["preview_output"]
assert "filename" in preview, "Preview should have filename"
assert "nodeId" in preview, "Preview should have nodeId"
assert "mediaType" in preview, "Preview should have mediaType"
def test_jobs_api_pagination(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API pagination"""
for _ in range(5):
self._create_history_item(client, builder)
first_page = client.get_jobs(limit=2, offset=0)
second_page = client.get_jobs(limit=2, offset=2)
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
first_ids = {j["id"] for j in first_page["jobs"]}
second_ids = {j["id"] for j in second_page["jobs"]}
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
def test_jobs_api_sorting(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API sorting"""
for _ in range(3):
self._create_history_item(client, builder)
desc_jobs = client.get_jobs(sort_order="desc")
asc_jobs = client.get_jobs(sort_order="asc")
if len(desc_jobs["jobs"]) >= 2:
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]]
if len(desc_times) >= 2:
assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first"
if len(asc_times) >= 2:
assert asc_times == sorted(asc_times), "Asc should be oldest first"
def test_jobs_api_status_filter(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API status filtering"""
self._create_history_item(client, builder)
completed_jobs = client.get_jobs(status="completed")
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
for job in completed_jobs["jobs"]:
assert job["status"] == "completed", "Should only return completed jobs"
# Pending jobs are transient - just verify filter doesn't error
pending_jobs = client.get_jobs(status="pending")
for job in pending_jobs["jobs"]:
assert job["status"] == "pending", "Should only return pending jobs"
def test_get_job_by_id(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test getting a single job by ID"""
result = self._create_history_item(client, builder)
prompt_id = result.get_prompt_id()
job = client.get_job(prompt_id)
assert job is not None, "Should find the job"
assert job["id"] == prompt_id, "Job ID should match"
assert "outputs" in job, "Single job should include outputs"
def test_get_job_not_found(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test getting a non-existent job returns 404"""
job = client.get_job("nonexistent-job-id")
assert job is None, "Non-existent job should return None"

View File

@ -0,0 +1,361 @@
"""Unit tests for comfy_execution/jobs.py"""
from comfy_execution.jobs import (
JobStatus,
is_previewable,
normalize_queue_item,
normalize_history_item,
get_outputs_summary,
apply_sorting,
)
class TestJobStatus:
"""Test JobStatus constants."""
def test_status_values(self):
"""Status constants should have expected string values."""
assert JobStatus.PENDING == 'pending'
assert JobStatus.IN_PROGRESS == 'in_progress'
assert JobStatus.COMPLETED == 'completed'
assert JobStatus.FAILED == 'failed'
def test_all_contains_all_statuses(self):
"""ALL should contain all status values."""
assert JobStatus.PENDING in JobStatus.ALL
assert JobStatus.IN_PROGRESS in JobStatus.ALL
assert JobStatus.COMPLETED in JobStatus.ALL
assert JobStatus.FAILED in JobStatus.ALL
assert len(JobStatus.ALL) == 4
class TestIsPreviewable:
"""Unit tests for is_previewable()"""
def test_previewable_media_types(self):
"""Images, video, audio media types should be previewable."""
for media_type in ['images', 'video', 'audio']:
assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self):
"""Other media types should not be previewable."""
for media_type in ['latents', 'text', 'metadata', 'files']:
assert is_previewable(media_type, {}) is False
def test_3d_extensions_previewable(self):
"""3D file extensions should be previewable regardless of media_type."""
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
item = {'filename': f'model{ext}'}
assert is_previewable('files', item) is True
def test_3d_extensions_case_insensitive(self):
"""3D extension check should be case insensitive."""
item = {'filename': 'MODEL.GLB'}
assert is_previewable('files', item) is True
def test_video_format_previewable(self):
"""Items with video/ format should be previewable."""
item = {'format': 'video/mp4'}
assert is_previewable('files', item) is True
def test_audio_format_previewable(self):
"""Items with audio/ format should be previewable."""
item = {'format': 'audio/wav'}
assert is_previewable('files', item) is True
def test_other_format_not_previewable(self):
"""Items with other format should not be previewable."""
item = {'format': 'application/json'}
assert is_previewable('files', item) is False
class TestGetOutputsSummary:
"""Unit tests for get_outputs_summary()"""
def test_empty_outputs(self):
"""Empty outputs should return 0 count and None preview."""
count, preview = get_outputs_summary({})
assert count == 0
assert preview is None
def test_counts_across_multiple_nodes(self):
"""Outputs from multiple nodes should all be counted."""
outputs = {
'node1': {'images': [{'filename': 'a.png', 'type': 'output'}]},
'node2': {'images': [{'filename': 'b.png', 'type': 'output'}]},
'node3': {'images': [
{'filename': 'c.png', 'type': 'output'},
{'filename': 'd.png', 'type': 'output'}
]}
}
count, preview = get_outputs_summary(outputs)
assert count == 4
def test_skips_animated_key_and_non_list_values(self):
"""The 'animated' key and non-list values should be skipped."""
outputs = {
'node1': {
'images': [{'filename': 'test.png', 'type': 'output'}],
'animated': [True], # Should skip due to key name
'metadata': 'string', # Should skip due to non-list
'count': 42 # Should skip due to non-list
}
}
count, preview = get_outputs_summary(outputs)
assert count == 1
def test_preview_prefers_type_output(self):
"""Items with type='output' should be preferred for preview."""
outputs = {
'node1': {
'images': [
{'filename': 'temp.png', 'type': 'temp'},
{'filename': 'output.png', 'type': 'output'}
]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 2
assert preview['filename'] == 'output.png'
def test_preview_fallback_when_no_output_type(self):
"""If no type='output', should use first previewable."""
outputs = {
'node1': {
'images': [
{'filename': 'temp1.png', 'type': 'temp'},
{'filename': 'temp2.png', 'type': 'temp'}
]
}
}
count, preview = get_outputs_summary(outputs)
assert preview['filename'] == 'temp1.png'
def test_non_previewable_media_types_counted_but_no_preview(self):
"""Non-previewable media types should be counted but not used as preview."""
outputs = {
'node1': {
'latents': [
{'filename': 'latent1.safetensors'},
{'filename': 'latent2.safetensors'}
]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 2
assert preview is None
def test_previewable_media_types(self):
"""Images, video, and audio media types should be previewable."""
for media_type in ['images', 'video', 'audio']:
outputs = {
'node1': {
media_type: [{'filename': 'test.file', 'type': 'output'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None, f"{media_type} should be previewable"
def test_3d_files_previewable(self):
"""3D file extensions should be previewable."""
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
outputs = {
'node1': {
'files': [{'filename': f'model{ext}', 'type': 'output'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None, f"3D file {ext} should be previewable"
def test_format_mime_type_previewable(self):
"""Files with video/ or audio/ format should be previewable."""
for fmt in ['video/x-custom', 'audio/x-custom']:
outputs = {
'node1': {
'files': [{'filename': 'file.custom', 'format': fmt, 'type': 'output'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None, f"Format {fmt} should be previewable"
def test_preview_enriched_with_node_metadata(self):
"""Preview should include nodeId, mediaType, and original fields."""
outputs = {
'node123': {
'images': [{'filename': 'test.png', 'type': 'output', 'subfolder': 'outputs'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview['nodeId'] == 'node123'
assert preview['mediaType'] == 'images'
assert preview['subfolder'] == 'outputs'
class TestApplySorting:
"""Unit tests for apply_sorting()"""
def test_sort_by_create_time_desc(self):
"""Default sort by create_time descending."""
jobs = [
{'id': 'a', 'create_time': 100},
{'id': 'b', 'create_time': 300},
{'id': 'c', 'create_time': 200},
]
result = apply_sorting(jobs, 'created_at', 'desc')
assert [j['id'] for j in result] == ['b', 'c', 'a']
def test_sort_by_create_time_asc(self):
"""Sort by create_time ascending."""
jobs = [
{'id': 'a', 'create_time': 100},
{'id': 'b', 'create_time': 300},
{'id': 'c', 'create_time': 200},
]
result = apply_sorting(jobs, 'created_at', 'asc')
assert [j['id'] for j in result] == ['a', 'c', 'b']
def test_sort_by_execution_duration(self):
"""Sort by execution_duration should order by duration."""
jobs = [
{'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100}, # 5s
{'id': 'b', 'create_time': 300, 'execution_start_time': 300, 'execution_end_time': 1300}, # 1s
{'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200}, # 3s
]
result = apply_sorting(jobs, 'execution_duration', 'desc')
assert [j['id'] for j in result] == ['a', 'c', 'b']
def test_sort_with_none_values(self):
"""Jobs with None values should sort as 0."""
jobs = [
{'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100},
{'id': 'b', 'create_time': 300, 'execution_start_time': None, 'execution_end_time': None},
{'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200},
]
result = apply_sorting(jobs, 'execution_duration', 'asc')
assert result[0]['id'] == 'b' # None treated as 0, comes first
class TestNormalizeQueueItem:
"""Unit tests for normalize_queue_item()"""
def test_basic_normalization(self):
"""Queue item should be normalized to job dict."""
item = (
10, # priority/number
'prompt-123', # prompt_id
{'nodes': {}}, # prompt
{
'create_time': 1234567890,
'extra_pnginfo': {'workflow': {'id': 'workflow-abc'}}
}, # extra_data
['node1'], # outputs_to_execute
)
job = normalize_queue_item(item, JobStatus.PENDING)
assert job['id'] == 'prompt-123'
assert job['status'] == 'pending'
assert job['priority'] == 10
assert job['create_time'] == 1234567890
assert 'execution_start_time' not in job
assert 'execution_end_time' not in job
assert 'execution_error' not in job
assert 'preview_output' not in job
assert job['outputs_count'] == 0
assert job['workflow_id'] == 'workflow-abc'
class TestNormalizeHistoryItem:
"""Unit tests for normalize_history_item()"""
def test_completed_job(self):
"""Completed history item should have correct status and times from messages."""
history_item = {
'prompt': (
5, # priority
'prompt-456',
{'nodes': {}},
{
'create_time': 1234567890000,
'extra_pnginfo': {'workflow': {'id': 'workflow-xyz'}}
},
['node1'],
),
'status': {
'status_str': 'success',
'completed': True,
'messages': [
('execution_start', {'prompt_id': 'prompt-456', 'timestamp': 1234567890500}),
('execution_success', {'prompt_id': 'prompt-456', 'timestamp': 1234567893000}),
]
},
'outputs': {},
}
job = normalize_history_item('prompt-456', history_item)
assert job['id'] == 'prompt-456'
assert job['status'] == 'completed'
assert job['priority'] == 5
assert job['execution_start_time'] == 1234567890500
assert job['execution_end_time'] == 1234567893000
assert job['workflow_id'] == 'workflow-xyz'
def test_failed_job(self):
"""Failed history item should have failed status and error from messages."""
history_item = {
'prompt': (
5,
'prompt-789',
{'nodes': {}},
{'create_time': 1234567890000},
['node1'],
),
'status': {
'status_str': 'error',
'completed': False,
'messages': [
('execution_start', {'prompt_id': 'prompt-789', 'timestamp': 1234567890500}),
('execution_error', {
'prompt_id': 'prompt-789',
'node_id': '5',
'node_type': 'KSampler',
'exception_message': 'CUDA out of memory',
'exception_type': 'RuntimeError',
'traceback': ['Traceback...', 'RuntimeError: CUDA out of memory'],
'timestamp': 1234567891000,
})
]
},
'outputs': {},
}
job = normalize_history_item('prompt-789', history_item)
assert job['status'] == 'failed'
assert job['execution_start_time'] == 1234567890500
assert job['execution_end_time'] == 1234567891000
assert job['execution_error']['node_id'] == '5'
assert job['execution_error']['node_type'] == 'KSampler'
assert job['execution_error']['exception_message'] == 'CUDA out of memory'
def test_include_outputs(self):
"""When include_outputs=True, should include full output data."""
history_item = {
'prompt': (
5,
'prompt-123',
{'nodes': {'1': {}}},
{'create_time': 1234567890, 'client_id': 'abc'},
['node1'],
),
'status': {'status_str': 'success', 'completed': True, 'messages': []},
'outputs': {'node1': {'images': [{'filename': 'test.png'}]}},
}
job = normalize_history_item('prompt-123', history_item, include_outputs=True)
assert 'outputs' in job
assert 'workflow' in job
assert 'execution_status' in job
assert job['outputs'] == {'node1': {'images': [{'filename': 'test.png'}]}}
assert job['workflow'] == {
'prompt': {'nodes': {'1': {}}},
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
}

View File

@ -0,0 +1,358 @@
"""
E2E tests for Queue-specific Preview Method Override feature.
Tests actual execution with different preview_method values.
Requires a running ComfyUI server with models.
Usage:
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
Note:
These tests execute actual image generation and wait for completion.
Tests verify preview image transmission based on preview_method setting.
"""
import os
import json
import pytest
import uuid
import time
import random
import websocket
import urllib.request
from pathlib import Path
# Server configuration
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
# Use existing inference graph fixture
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
def is_server_running() -> bool:
"""Check if ComfyUI server is running."""
try:
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
with urllib.request.urlopen(request, timeout=2.0):
return True
except Exception:
return False
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
"""Prepare graph for testing: randomize seeds and reduce steps."""
adapted = json.loads(json.dumps(graph)) # Deep copy
for node_id, node in adapted.items():
inputs = node.get("inputs", {})
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
if "seed" in inputs:
inputs["seed"] = random.randint(0, 2**32 - 1)
if "noise_seed" in inputs:
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
# Reduce steps for faster testing (default 20 -> 5)
if "steps" in inputs:
inputs["steps"] = steps
return adapted
# Alias for backward compatibility
randomize_seed = prepare_graph_for_test
class PreviewMethodClient:
"""Client for testing preview_method with WebSocket execution tracking."""
def __init__(self, server_address: str):
self.server_address = server_address
self.client_id = str(uuid.uuid4())
self.ws = None
def connect(self):
"""Connect to WebSocket."""
self.ws = websocket.WebSocket()
self.ws.settimeout(120) # 2 minute timeout for sampling
self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
def close(self):
"""Close WebSocket connection."""
if self.ws:
self.ws.close()
def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
"""Queue a prompt and return response with prompt_id."""
data = {
"prompt": prompt,
"client_id": self.client_id,
"extra_data": extra_data or {}
}
req = urllib.request.Request(
f"http://{self.server_address}/prompt",
data=json.dumps(data).encode("utf-8"),
headers={"Content-Type": "application/json"}
)
return json.loads(urllib.request.urlopen(req).read())
def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
"""
Wait for execution to complete via WebSocket.
Returns:
dict with keys: completed, error, preview_count, execution_time
"""
result = {
"completed": False,
"error": None,
"preview_count": 0,
"execution_time": 0.0
}
start_time = time.time()
self.ws.settimeout(timeout)
try:
while True:
out = self.ws.recv()
elapsed = time.time() - start_time
if isinstance(out, str):
message = json.loads(out)
msg_type = message.get("type")
data = message.get("data", {})
if data.get("prompt_id") != prompt_id:
continue
if msg_type == "executing":
if data.get("node") is None:
# Execution complete
result["completed"] = True
result["execution_time"] = elapsed
break
elif msg_type == "execution_error":
result["error"] = data
result["execution_time"] = elapsed
break
elif msg_type == "progress":
# Progress update during sampling
pass
elif isinstance(out, bytes):
# Binary data = preview image
result["preview_count"] += 1
except websocket.WebSocketTimeoutException:
result["error"] = "Timeout waiting for execution"
result["execution_time"] = time.time() - start_time
return result
def load_graph() -> dict:
"""Load the SDXL graph fixture with randomized seed."""
with open(GRAPH_FILE) as f:
graph = json.load(f)
return randomize_seed(graph) # Avoid caching
# Skip all tests if server is not running
pytestmark = [
pytest.mark.skipif(
not is_server_running(),
reason=f"ComfyUI server not running at {SERVER_URL}"
),
pytest.mark.preview_method,
pytest.mark.execution,
]
@pytest.fixture
def client():
"""Create and connect a test client."""
c = PreviewMethodClient(SERVER_HOST)
c.connect()
yield c
c.close()
@pytest.fixture
def graph():
"""Load the test graph."""
return load_graph()
class TestPreviewMethodExecution:
"""Test actual execution with different preview methods."""
def test_execution_with_latent2rgb(self, client, graph):
"""
Execute with preview_method=latent2rgb.
Should complete and potentially receive preview images.
"""
extra_data = {"preview_method": "latent2rgb"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
# Should complete (may error if model missing, but that's separate)
assert result["completed"] or result["error"] is not None
# Execution should take some time (sampling)
if result["completed"]:
assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
# latent2rgb should produce previews
print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_taesd(self, client, graph):
"""
Execute with preview_method=taesd.
TAESD provides higher quality previews.
"""
extra_data = {"preview_method": "taesd"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
assert result["execution_time"] > 0.5
# taesd should also produce previews
print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_none_preview(self, client, graph):
"""
Execute with preview_method=none.
No preview images should be generated.
"""
extra_data = {"preview_method": "none"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
# With "none", should receive no preview images
assert result["preview_count"] == 0, \
f"Expected no previews with 'none', got {result['preview_count']}"
print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_default(self, client, graph):
"""
Execute with preview_method=default.
Should use server's CLI default setting.
"""
extra_data = {"preview_method": "default"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_without_preview_method(self, client, graph):
"""
Execute without preview_method in extra_data.
Should use server's default preview method.
"""
extra_data = {} # No preview_method
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
class TestPreviewMethodComparison:
"""Compare preview behavior between different methods."""
def test_none_vs_latent2rgb_preview_count(self, client, graph):
"""
Compare preview counts: 'none' should have 0, others should have >0.
This is the key verification that preview_method actually works.
"""
results = {}
# Run with none (randomize seed to avoid caching)
graph_none = randomize_seed(graph)
extra_data_none = {"preview_method": "none"}
response = client.queue_prompt(graph_none, extra_data_none)
results["none"] = client.wait_for_execution(response["prompt_id"])
# Run with latent2rgb (randomize seed again)
graph_rgb = randomize_seed(graph)
extra_data_rgb = {"preview_method": "latent2rgb"}
response = client.queue_prompt(graph_rgb, extra_data_rgb)
results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])
# Verify both completed
assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"
# Key assertion: 'none' should have 0 previews
assert results["none"]["preview_count"] == 0, \
f"'none' should have 0 previews, got {results['none']['preview_count']}"
# 'latent2rgb' should have at least 1 preview (depends on steps)
assert results["latent2rgb"]["preview_count"] > 0, \
f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"
print("\nPreview count comparison:") # noqa: T201
print(f" none: {results['none']['preview_count']} previews") # noqa: T201
print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201
class TestPreviewMethodSequential:
"""Test sequential execution with different preview methods."""
def test_sequential_different_methods(self, client, graph):
"""
Execute multiple prompts sequentially with different preview methods.
Each should complete independently with correct preview behavior.
"""
methods = ["latent2rgb", "none", "default"]
results = []
for method in methods:
# Randomize seed for each execution to avoid caching
graph_run = randomize_seed(graph)
extra_data = {"preview_method": method}
response = client.queue_prompt(graph_run, extra_data)
result = client.wait_for_execution(response["prompt_id"])
results.append({
"method": method,
"completed": result["completed"],
"preview_count": result["preview_count"],
"execution_time": result["execution_time"],
"error": result["error"]
})
# All should complete or have clear errors
for r in results:
assert r["completed"] or r["error"] is not None, \
f"Method {r['method']} neither completed nor errored"
# "none" should have zero previews if completed
none_result = next(r for r in results if r["method"] == "none")
if none_result["completed"]:
assert none_result["preview_count"] == 0, \
f"'none' should have 0 previews, got {none_result['preview_count']}"
print("\nSequential execution results:") # noqa: T201
for r in results:
status = "" if r["completed"] else f"✗ ({r['error']})"
print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201