Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2025-12-26 14:36:35 -08:00
commit 1b29ff09da
42 changed files with 2457 additions and 500 deletions

View File

@ -1,6 +1,6 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.5.0"
__version__ = "0.6.0"
# This deals with workspace issues
from comfy_compatibility.workspace import auto_patch_workspace_and_restart

View File

@ -33,6 +33,7 @@ from typing_extensions import NamedTuple
from comfy_api import feature_flags
from comfy_api.internal import _ComfyNodeInternal
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
from .latent_preview_image_encoding import encode_preview_image
from .. import __version__
from .. import interruption, model_management
@ -82,6 +83,21 @@ LOADED_MODULE_DIRS = {}
# logger.warning("ComfyUI Manager not found but enabled in args.")
def _remove_sensitive_from_queue(queue: list) -> list:
"""Remove sensitive data (index 5) from queue item tuples."""
items = []
for item in queue:
if isinstance(item, tuple) or isinstance(item, list) and len(item) >= 5:
items.append(item[:5])
else:
items.append({
**item,
"sensitive": None,
})
return items
async def send_socket_catch_exception(function, message):
try:
await function(message)
@ -759,6 +775,129 @@ class PromptServer(ExecutorToClientProgress):
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/api/jobs")
async def get_jobs(request):
"""List all jobs with filtering, sorting, and pagination.
Query parameters:
status: Filter by status (comma-separated): pending, in_progress, completed, failed
workflow_id: Filter by workflow ID
sort_by: Sort field: created_at (default), execution_duration
sort_order: Sort direction: asc, desc (default)
limit: Max items to return (positive integer)
offset: Items to skip (non-negative integer, default 0)
"""
query = request.rel_url.query
status_param = query.get('status')
workflow_id = query.get('workflow_id')
sort_by = query.get('sort_by', 'created_at').lower()
sort_order = query.get('sort_order', 'desc').lower()
status_filter = None
if status_param:
status_filter = [s.strip().lower() for s in status_param.split(',') if s.strip()]
invalid_statuses = [s for s in status_filter if s not in JobStatus.ALL]
if invalid_statuses:
return web.json_response(
{"error": f"Invalid status value(s): {', '.join(invalid_statuses)}. Valid values: {', '.join(JobStatus.ALL)}"},
status=400
)
if sort_by not in {'created_at', 'execution_duration'}:
return web.json_response(
{"error": "sort_by must be 'created_at' or 'execution_duration'"},
status=400
)
if sort_order not in {'asc', 'desc'}:
return web.json_response(
{"error": "sort_order must be 'asc' or 'desc'"},
status=400
)
limit = None
# If limit is provided, validate that it is a positive integer, else continue without a limit
if 'limit' in query:
try:
limit = int(query.get('limit'))
if limit <= 0:
return web.json_response(
{"error": "limit must be a positive integer"},
status=400
)
except (ValueError, TypeError):
return web.json_response(
{"error": "limit must be an integer"},
status=400
)
offset = 0
if 'offset' in query:
try:
offset = int(query.get('offset'))
if offset < 0:
offset = 0
except (ValueError, TypeError):
return web.json_response(
{"error": "offset must be an integer"},
status=400
)
running, queued = self.prompt_queue.get_current_queue_volatile()
history = self.prompt_queue.get_history()
running = _remove_sensitive_from_queue(running)
queued = _remove_sensitive_from_queue(queued)
jobs, total = get_all_jobs(
running, queued, history,
status_filter=status_filter,
workflow_id=workflow_id,
sort_by=sort_by,
sort_order=sort_order,
limit=limit,
offset=offset
)
has_more = (offset + len(jobs)) < total
return web.json_response({
'jobs': jobs,
'pagination': {
'offset': offset,
'limit': limit,
'total': total,
'has_more': has_more
}
})
@routes.get("/api/jobs/{job_id}")
async def get_job_by_id(request):
"""Get a single job by ID."""
job_id = request.match_info.get("job_id", None)
if not job_id:
return web.json_response(
{"error": "job_id is required"},
status=400
)
running, queued = self.prompt_queue.get_current_queue_volatile()
history = self.prompt_queue.get_history(prompt_id=job_id)
running = _remove_sensitive_from_queue(running)
queued = _remove_sensitive_from_queue(queued)
job = get_job(job_id, running, queued, history)
if job is None:
return web.json_response(
{"error": "Job not found"},
status=404
)
return web.json_response(job)
@routes.get("/history")
async def get_history(request):
max_items = request.rel_url.query.get("max_items", None)
@ -782,18 +921,8 @@ class PromptServer(ExecutorToClientProgress):
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue_volatile()
def remove_sensitive(queue: List[QueueItem]):
items = []
for item in queue:
items.append({
**item,
"sensitive": None,
})
return items
queue_info['queue_running'] = remove_sensitive(current_queue[0])
queue_info['queue_pending'] = remove_sensitive(current_queue[1])
queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0])
queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
return web.json_response(queue_info)
@routes.post("/prompt")

View File

@ -152,7 +152,7 @@ class IndexListContextHandler(ContextHandlerABC):
# if multiple conds, split based on primary region
if self.split_conds_to_windows and len(cond_in) > 1:
region = window.get_region_index(len(cond_in))
logger.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
logger.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:

View File

@ -1819,7 +1819,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
# Predictor
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
x_pred = denoised
else:
tau_t = tau_func(sigmas[i + 1])
curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
@ -1840,7 +1840,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
if tau_t > 0 and s_noise > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
x_pred = x_pred + noise
return x
return x_pred
@torch.no_grad()

View File

@ -643,7 +643,7 @@ class NextDiT(nn.Module):
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))

View File

@ -62,7 +62,7 @@ def apply_rotary_emb(x, freqs_cis):
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(
@ -73,9 +73,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
operations=operations
)
def forward(self, timestep, hidden_states):
self.use_additional_t_cond = use_additional_t_cond
if self.use_additional_t_cond:
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
def forward(self, timestep, hidden_states, addition_t_cond=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
if self.use_additional_t_cond:
if addition_t_cond is None:
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
return timesteps_emb
@ -325,10 +335,10 @@ class QwenImageTransformer2DModel(nn.Module):
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",image_model=None,
final_layer=True, dtype=None,
final_layer=True,use_additional_t_cond=False, dtype=None,
device=None,
operations=None,
):
@ -345,6 +355,7 @@ class QwenImageTransformer2DModel(nn.Module):
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim,
use_additional_t_cond=use_additional_t_cond,
dtype=dtype,
device=device,
operations=operations
@ -378,29 +389,35 @@ class QwenImageTransformer2DModel(nn.Module):
patch_size = self.patch_size
hidden_states = pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
t_len = t
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options=None, **kwargs):
if t_len > 1:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
else:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options=None, **kwargs):
if transformer_options is None:
transformer_options = {}
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
def _forward(
self,
@ -408,8 +425,8 @@ class QwenImageTransformer2DModel(nn.Module):
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
ref_latents=None,
ref_latents = None,
additional_t_cond=None,
transformer_options=None,
control=None,
**kwargs
@ -430,12 +447,17 @@ class QwenImageTransformer2DModel(nn.Module):
index = 0
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
negative_ref_method = ref_method == "negative_index"
timestep_zero = ref_method == "index_timestep_zero"
for ref in ref_latents:
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
elif negative_ref_method:
index -= 1
h_offset = 0
w_offset = 0
else:
index = 1
h_offset = 0
@ -465,14 +487,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
patches_replace = transformer_options.get("patches_replace", {})
patches = transformer_options.get("patches", {})
@ -521,6 +536,6 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
input_channels=3,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
@ -245,7 +246,7 @@ class Encoder3d(nn.Module):
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
output_channels=3,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
@ -378,7 +380,7 @@ class Decoder3d(nn.Module):
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
CausalConv3d(out_dim, output_channels, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
@ -449,6 +451,7 @@ class WanVAE(nn.Module):
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
image_channels=3,
dropout=0.0):
super().__init__()
self.dim = dim
@ -460,11 +463,11 @@ class WanVAE(nn.Module):
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def encode(self, x):

View File

@ -1154,7 +1154,7 @@ class Lumina2(BaseModel):
if 'num_tokens' not in out:
out['num_tokens'] = conds.CONDConstant(cross_attn.shape[1])
clip_text_pooled = kwargs["pooled_output"] # Newbie
clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
if clip_text_pooled is not None:
out['clip_text_pooled'] = conds.CONDRegular(clip_text_pooled)
return out

View File

@ -439,8 +439,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
if ctd_weight is not None:
if ctd_weight is not None: # NewBie
dit_config["clip_text_dim"] = ctd_weight.shape[0]
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
@ -629,6 +630,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
dit_config["use_additional_t_cond"] = True
dit_config["default_ref_method"] = "negative_index"
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5

View File

@ -19,6 +19,7 @@ from __future__ import annotations
from .cmd.main_pre import tracer
import os
import gc
import logging
import platform
@ -383,13 +384,15 @@ except:
SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logger.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logger.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:

View File

@ -371,10 +371,9 @@ class VAEEncode:
def encode(self, vae: VAE, pixels) -> tuple[Optional[Latent]]:
if pixels is None:
return None,
t = vae.encode(pixels[:, :, :, :3])
t = vae.encode(pixels)
return (Latent(**{"samples": t}),)
class VAEEncodeTiled:
@classmethod
def INPUT_TYPES(s):
@ -393,10 +392,9 @@ class VAEEncodeTiled:
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8) -> tuple[Optional[Latent]]:
if pixels is None:
return None,
t = vae.encode_tiled(pixels[:, :, :, :3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
return (Latent(**{"samples": t}),)
class VAEEncodeForInpaint:
@classmethod
def INPUT_TYPES(s):
@ -1055,7 +1053,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": {"clip_name1": (get_filename_list_with_downloadable("text_encoders"),), "clip_name2": (
get_filename_list_with_downloadable("text_encoders"),),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"],),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"],),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -1066,7 +1064,7 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_type = getattr(sd.CLIPType, type.upper(), sd.CLIPType.STABLE_DIFFUSION)

View File

@ -133,21 +133,21 @@ def estimate_memory(model, noise_shape, conds):
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
executor = patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
patcher_extension.get_all_wrappers(patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options)
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
real_model = model.model
return real_model, conds, models

View File

@ -1031,9 +1031,6 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
@ -1060,6 +1057,24 @@ class CFGGuider:
else:
latent_shapes = [latent_image.shape]
if denoise_mask is not None:
if denoise_mask.is_nested:
denoise_masks = denoise_mask.unbind()
denoise_masks = denoise_masks[:len(latent_shapes)]
else:
denoise_masks = [denoise_mask]
for i in range(len(denoise_masks), len(latent_shapes)):
denoise_masks.append(torch.ones(latent_shapes[i]))
for i in range(len(denoise_masks)):
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
if len(denoise_masks) > 1:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

View File

@ -66,6 +66,8 @@ from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
from .text_encoders import wan
from .text_encoders import z_image
from .text_encoders import jina_clip_2
from .text_encoders import newbie
from .utils import ProgressBar, FileMetadata, state_dict_prefix_replace
from .taesd.taehv import TAEHV
from .latent_formats import HunyuanVideo15, HunyuanVideo
@ -337,6 +339,7 @@ class VAE:
self.latent_channels = 4
self.latent_dim = 2
self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
@ -451,6 +454,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
self.latent_channels = 64
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.latent_dim = 1
@ -562,7 +566,9 @@ class VAE:
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.output_channels = sd["encoder.conv1.weight"].shape[1]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
self.first_stage_model = wan_vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2] <= 4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
@ -598,6 +604,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = 8
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
@ -730,17 +737,28 @@ class VAE:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
if not self.crop_input:
return pixels
if self.crop_input:
downscale_ratio = self.spacial_compression_encode()
downscale_ratio = self.spacial_compression_encode()
dims = pixels.shape[1:-1]
for d in range(len(dims)):
x = (dims[d] // downscale_ratio) * downscale_ratio
x_offset = (dims[d] % downscale_ratio) // 2
if x != dims[d]:
pixels = pixels.narrow(d + 1, x_offset, x)
dims = pixels.shape[1:-1]
for d in range(len(dims)):
x = (dims[d] // downscale_ratio) * downscale_ratio
x_offset = (dims[d] % downscale_ratio) // 2
if x != dims[d]:
pixels = pixels.narrow(d + 1, x_offset, x)
if pixels.shape[-1] > self.output_channels:
pixels = pixels[..., :self.output_channels]
elif pixels.shape[-1] < self.output_channels:
if self.pad_channel_value is not None:
if isinstance(self.pad_channel_value, str):
mode = self.pad_channel_value
value = None
else:
mode = "constant"
value = self.pad_channel_value
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
@ -1048,6 +1066,7 @@ class CLIPType(Enum):
OVIS = 21
KANDINSKY5 = 22
KANDINSKY5_IMAGE = 23
NEWBIE = 24
@dataclasses.dataclass
@ -1088,6 +1107,7 @@ class TEModel(Enum):
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
QWEN3_2B = 17
JINA_CLIP_2 = 18
def detect_te_model(sd):
@ -1097,6 +1117,8 @@ def detect_te_model(sd):
return TEModel.CLIP_H
if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
return TEModel.CLIP_L
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
@ -1259,6 +1281,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_2B:
clip_target.clip = ovis.te(**llama_detect(clip_data))
clip_target.tokenizer = ovis.OvisTokenizer
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = jina_clip_2.JinaClip2TokenizerWrapper
else:
# clip_l
if clip_type == CLIPType.SD3:
@ -1314,6 +1339,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
clip_target.clip = kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.NEWBIE:
clip_target.clip = newbie.te(**llama_detect(clip_data))
clip_target.tokenizer = newbie.NewBieTokenizer
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
clip_data_gemma = clip_data[0]
clip_data_jina = clip_data[1]
else:
clip_data_gemma = clip_data[1]
clip_data_jina = clip_data[0]
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer

View File

@ -12,6 +12,7 @@ from pathlib import Path
from typing import Tuple, Sequence, TypeVar, Callable, Optional, Union
import torch
try:
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
except ImportError:
@ -557,7 +558,7 @@ SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
class SDTokenizer:
def __init__(self, tokenizer_path: Optional[Union[torch.Tensor, bytes, bytearray, memoryview, str, Path, Traversable]] = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data=None, tokenizer_args=None):
def __init__(self, tokenizer_path: Optional[Union[torch.Tensor, bytes, bytearray, memoryview, str, Path, Traversable]] = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data=None, tokenizer_args=None):
if tokenizer_data is None:
tokenizer_data = dict()
if tokenizer_args is None:
@ -617,6 +618,8 @@ class SDTokenizer:
self.embedding_size = embedding_size
self.embedding_key = embedding_key
self.disable_weights = disable_weights
def clone(self) -> SDTokenizerT:
sd_tokenizer = copy.copy(self)
# correctly copy additional vocab
@ -665,7 +668,7 @@ class SDTokenizer:
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text)
if kwargs.get("disable_weights", False):
if kwargs.get("disable_weights", self.disable_weights):
parsed_weights = [(text, 1.0)]
else:
parsed_weights = token_weights(text, 1.0)

View File

@ -0,0 +1,219 @@
# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py
from dataclasses import dataclass
import torch
from torch import nn as nn
from torch.nn import functional as F
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
# The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")
# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
vocab_size: int = 250002
type_vocab_size: int = 1
hidden_size: int = 1024
num_hidden_layers: int = 24
num_attention_heads: int = 16
rotary_emb_base: float = 20000.0
intermediate_size: int = 4096
hidden_act: str = "gelu"
hidden_dropout_prob: float = 0.1
attention_probs_dropout_prob: float = 0.1
layer_norm_eps: float = 1e-05
bos_token_id: int = 0
eos_token_id: int = 2
pad_token_id: int = 1
class XLMRobertaEmbeddings(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
embed_dim = config.hidden_size
self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)
def forward(self, input_ids=None, embeddings=None):
if input_ids is not None and embeddings is None:
embeddings = self.word_embeddings(input_ids)
if embeddings is not None:
token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = embeddings + token_type_embeddings
return embeddings
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=torch.float32)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
emb = torch.cat((freqs, freqs), dim=-1)
self._cos_cached = emb.cos().to(dtype)
self._sin_cached = emb.sin().to(dtype)
def forward(self, q, k):
batch, seqlen, heads, head_dim = q.shape
self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)
def rotate_half(x):
size = x.shape[-1] // 2
x1, x2 = x[..., :size], x[..., size:]
return torch.cat((-x2, x1), dim=-1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class MHA(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = embed_dim // config.num_attention_heads
self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
def forward(self, x, mask=None, optimized_attention=None):
qkv = self.Wqkv(x)
batch_size, seq_len, _ = qkv.shape
qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(2)
q, k = self.rotary_emb(q, k)
# NHD -> HND
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
return self.out_proj(out)
class MLP(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
self.activation = F.gelu
self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class Block(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
def forward(self, hidden_states, mask=None, optimized_attention=None):
mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
mlp_out = self.mlp(hidden_states)
hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
return hidden_states
class XLMRobertaEncoder(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None):
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
for layer in self.layers:
hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
return hidden_states
class XLMRobertaModel_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)
def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
x = self.embeddings(input_ids=input_ids, embeddings=embeds)
x = self.emb_ln(x)
x = self.emb_drop(x)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
sequence_output = self.encoder(x, attention_mask=mask)
# Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
pooled_output = None
if attention_mask is None:
pooled_output = sequence_output.mean(dim=1)
else:
attention_mask = attention_mask.to(sequence_output.dtype)
pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)
# Intermediate output is not yet implemented, use None for placeholder
return sequence_output, None, pooled_output
class XLMRobertaModel(nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.config = XLMRobertaConfig(**config_dict)
self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
self.num_layers = self.config.num_hidden_layers
def get_input_embeddings(self):
return self.model.embeddings.word_embeddings
def set_input_embeddings(self, embeddings):
self.model.embeddings.word_embeddings = embeddings
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
class JinaClip2TextModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)

View File

@ -185,7 +185,7 @@ class Gemma3_4B_Config:
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0]
rope_theta = [1000000.0, 10000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
@ -194,8 +194,8 @@ class Gemma3_4B_Config:
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
@ -381,7 +381,7 @@ class TransformerBlockGemma2(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
if config.sliding_attention is not None:
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else:
self.sliding_attention = False
@ -398,7 +398,12 @@ class TransformerBlockGemma2(nn.Module):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
if x.shape[1] > self.sliding_attention:
logger.warning("Warning: sliding attention not implemented, results may be incorrect")
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None:
attention_mask = attention_mask + sliding_mask
else:
attention_mask = sliding_mask
freqs_cis = freqs_cis[1]
else:
freqs_cis = freqs_cis[0]

View File

@ -19,7 +19,7 @@ class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_directory=None, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_directory=None, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@ -49,8 +49,12 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, textmodel_json_config=None):
if textmodel_json_config is None:
textmodel_json_config = {}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None, name="gemma2_2b", clip_model=Gemma2_2BModel):

View File

@ -0,0 +1,62 @@
import torch
import comfy.model_management
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.lumina2
class NewBieTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
raise NotImplementedError
def state_dict(self):
return {}
class NewBieTEModel(torch.nn.Module):
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = {dtype, dtype_gemma}
def set_clip_options(self, options):
self.gemma.set_clip_options(options)
self.jina.set_clip_options(options)
def reset_clip_options(self):
self.gemma.reset_clip_options()
self.jina.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_gemma = token_weight_pairs["gemma"]
token_weight_pairs_jina = token_weight_pairs["jina"]
gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)
return gemma_out, jina_pooled, gemma_extra
def load_sd(self, sd):
if "model.layers.0.self_attn.q_norm.weight" in sd:
return self.gemma.load_sd(sd)
else:
return self.jina.load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class NewBieTEModel_(NewBieTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return NewBieTEModel_

View File

@ -1558,12 +1558,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None)
return type_clone
@final

View File

@ -10,7 +10,7 @@ class Text2ImageTaskCreationRequest(BaseModel):
size: str | None = Field(None)
seed: int | None = Field(0, ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
watermark: bool | None = Field(False)
class Image2ImageTaskCreationRequest(BaseModel):
@ -21,7 +21,7 @@ class Image2ImageTaskCreationRequest(BaseModel):
size: str | None = Field("adaptive")
seed: int | None = Field(..., ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
watermark: bool | None = Field(False)
class Seedream4Options(BaseModel):
@ -37,7 +37,7 @@ class Seedream4TaskCreationRequest(BaseModel):
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
watermark: bool = Field(False)
class ImageTaskCreationResponse(BaseModel):

View File

@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel):
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
tools: list[GeminiTool] | None = Field(None)
videoMetadata: GeminiVideoMetadata | None = Field(None)
uploadImagesToStorage: bool = Field(True)
class GeminiGenerateContentRequest(BaseModel):

View File

@ -0,0 +1,52 @@
from pydantic import BaseModel, Field
class Datum2(BaseModel):
b64_json: str | None = Field(None, description="Base64 encoded image data")
revised_prompt: str | None = Field(None, description="Revised prompt")
url: str | None = Field(None, description="URL of the image")
class InputTokensDetails(BaseModel):
image_tokens: int | None = None
text_tokens: int | None = None
class Usage(BaseModel):
input_tokens: int | None = None
input_tokens_details: InputTokensDetails | None = None
output_tokens: int | None = None
total_tokens: int | None = None
class OpenAIImageGenerationResponse(BaseModel):
data: list[Datum2] | None = None
usage: Usage | None = None
class OpenAIImageEditRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str = Field(...)
moderation: str | None = Field(None)
n: int | None = Field(None, description="The number of images to generate")
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
size: str | None = Field(None, description="Size of the output image")
class OpenAIImageGenerationRequest(BaseModel):
background: str | None = Field(None, description="Background transparency")
model: str | None = Field(None)
moderation: str | None = Field(None)
n: int | None = Field(
None,
description="The number of images to generate.",
)
output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
output_format: str | None = Field(None)
prompt: str = Field(...)
quality: str | None = Field(None, description="The quality of the generated image")
size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
style: str | None = Field(None, description="Style of the image (only for dall-e-3)")

View File

@ -1,10 +1,8 @@
from inspect import cleandoc
import torch
from pydantic import BaseModel
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl_api import (
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
@ -28,7 +26,7 @@ from comfy_api_nodes.util import (
)
def convert_mask_to_image(mask: torch.Tensor):
def convert_mask_to_image(mask: Input.Image):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor):
class FluxProUltraImageNode(IO.ComfyNode):
"""
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
node_id="FluxProUltraImageNode",
display_name="Flux 1.1 [pro] Ultra Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
inputs=[
IO.String.Input(
"prompt",
@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
prompt_upsampling: bool = False,
raw: bool = False,
seed: int = 0,
image_prompt: torch.Tensor | None = None,
image_prompt: Input.Image | None = None,
image_prompt_strength: float = 0.1,
) -> IO.NodeOutput:
if image_prompt is None:
@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
class FluxKontextProImageNode(IO.ComfyNode):
"""
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
"prompt",
@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
aspect_ratio: str,
guidance: float,
steps: int,
input_image: torch.Tensor | None = None,
input_image: Input.Image | None = None,
seed=0,
prompt_upsampling=False,
) -> IO.NodeOutput:
@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode):
class FluxKontextMaxImageNode(FluxKontextProImageNode):
"""
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
"""
DESCRIPTION = cleandoc(__doc__ or "")
DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
NODE_ID = "FluxKontextMaxImageNode"
DISPLAY_NAME = "Flux.1 Kontext [max] Image"
class FluxProExpandNode(IO.ComfyNode):
"""
Outpaints image based on prompt.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode):
node_id="FluxProExpandNode",
display_name="Flux.1 Expand Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Outpaints image based on prompt.",
inputs=[
IO.Image.Input("image"),
IO.String.Input(
@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
prompt: str,
prompt_upsampling: bool,
top: int,
@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode):
class FluxProFillNode(IO.ComfyNode):
"""
Inpaints image based on mask and prompt.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode):
node_id="FluxProFillNode",
display_name="Flux.1 Fill Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
description="Inpaints image based on mask and prompt.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask"),
@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: torch.Tensor,
mask: torch.Tensor,
image: Input.Image,
mask: Input.Image,
prompt: str,
prompt_upsampling: bool,
steps: int,
@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode):
class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
DISPLAY_NAME = "Flux.2 [pro] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Flux2ProImageNode",
display_name="Flux.2 [pro] Image",
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="api node/image/BFL",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode):
),
IO.Boolean.Input(
"prompt_upsampling",
default=False,
default=True,
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
"If active, automatically modifies the prompt for more creative generation.",
),
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
],
outputs=[IO.Image.Output()],
hidden=[
@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode):
height: int,
seed: int,
prompt_upsampling: bool,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
) -> IO.NodeOutput:
reference_images = {}
if images is not None:
@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode):
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
response_model=BFLFluxProGenerateResponse,
data=Flux2ProGenerateRequest(
prompt=prompt,
@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2MaxImageNode(Flux2ProImageNode):
NODE_ID = "Flux2MaxImageNode"
DISPLAY_NAME = "Flux.2 [max] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
class BFLExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension):
FluxProExpandNode,
FluxProFillNode,
Flux2ProImageNode,
Flux2MaxImageNode,
]

View File

@ -112,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the image',
optional=True,
),
@ -215,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the image',
optional=True,
),
@ -346,7 +346,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the image.',
optional=True,
),
@ -380,7 +380,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
sequential_image_generation: str = "disabled",
max_images: int = 1,
seed: int = 0,
watermark: bool = True,
watermark: bool = False,
fail_on_partial: bool = True,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
@ -507,7 +507,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
@ -617,7 +617,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
@ -739,7 +739,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),
@ -862,7 +862,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True,
),

View File

@ -34,6 +34,7 @@ from comfy_api_nodes.util import (
ApiEndpoint,
audio_to_base64_string,
bytesio_to_image_tensor,
download_url_to_image_tensor,
get_number_of_images,
sync_op,
tensor_to_base64_string,
@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
)
parts = []
for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text:
if part_type == "text" and part.text:
parts.append(part)
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
elif part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part)
# Skip parts that don't match the requested type
return parts
@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
else:
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image)
if len(image_tensors) == 0:
return torch.zeros((1, 1024, 1024, 4))
@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode):
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiImage2(IO.ComfyNode):
@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension):

View File

@ -858,7 +858,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("duration", options=["5", "10"]),
IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
IO.Image.Input("first_frame"),
IO.Image.Input(
"end_frame",
@ -897,6 +897,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
validate_string(prompt, min_length=1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
if duration not in (5, 10) and end_frame is None and reference_images is None:
raise ValueError(
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
)
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [

View File

@ -1,42 +1,44 @@
from io import BytesIO
import base64
import os
from enum import Enum
from inspect import cleandoc
from io import BytesIO
import numpy as np
import torch
from PIL import Image
from comfy.cmd import folder_paths
import base64
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from typing_extensions import override
from comfy_api_nodes.apis import (
OpenAIImageGenerationRequest,
OpenAIImageEditRequest,
OpenAIImageGenerationResponse,
OpenAICreateResponse,
OpenAIResponse,
CreateModelResponseProperties,
Item,
OutputContent,
InputImageContent,
Detail,
InputTextContent,
InputMessage,
InputMessageContentList,
InputContent,
InputFileContent,
InputImageContent,
InputMessage,
InputMessageContentList,
InputTextContent,
Item,
OpenAICreateResponse,
OpenAIResponse,
OutputContent,
)
from comfy_api_nodes.apis.openai_api import (
OpenAIImageEditRequest,
OpenAIImageGenerationRequest,
OpenAIImageGenerationResponse,
)
from comfy_api_nodes.util import (
downscale_image_tensor,
download_url_to_bytesio,
validate_string,
tensor_to_base64_string,
ApiEndpoint,
sync_op,
download_url_to_bytesio,
downscale_image_tensor,
poll_op,
sync_op,
tensor_to_base64_string,
text_filepath_to_data_uri,
validate_string,
)
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
@ -96,9 +98,6 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
class OpenAIDalle2(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's DALL·E 2 endpoint.
"""
@classmethod
def define_schema(cls):
@ -106,7 +105,7 @@ class OpenAIDalle2(IO.ComfyNode):
node_id="OpenAIDalle2",
display_name="OpenAI DALL·E 2",
category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
inputs=[
IO.String.Input(
"prompt",
@ -232,9 +231,6 @@ class OpenAIDalle2(IO.ComfyNode):
class OpenAIDalle3(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's DALL·E 3 endpoint.
"""
@classmethod
def define_schema(cls):
@ -242,7 +238,7 @@ class OpenAIDalle3(IO.ComfyNode):
node_id="OpenAIDalle3",
display_name="OpenAI DALL·E 3",
category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
inputs=[
IO.String.Input(
"prompt",
@ -324,10 +320,16 @@ class OpenAIDalle3(IO.ComfyNode):
return IO.NodeOutput(await validate_and_cast_response(response))
def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None:
# https://platform.openai.com/docs/pricing
return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0
def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None:
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
class OpenAIGPTImage1(IO.ComfyNode):
"""
Generates images synchronously via OpenAI's GPT Image 1 endpoint.
"""
@classmethod
def define_schema(cls):
@ -335,13 +337,13 @@ class OpenAIGPTImage1(IO.ComfyNode):
node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 1",
category="api node/image/OpenAI",
description=cleandoc(cls.__doc__ or ""),
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
inputs=[
IO.String.Input(
"prompt",
default="",
multiline=True,
tooltip="Text prompt for GPT Image 1",
tooltip="Text prompt for GPT Image",
),
IO.Int.Input(
"seed",
@ -363,8 +365,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
),
IO.Combo.Input(
"background",
default="opaque",
options=["opaque", "transparent"],
default="auto",
options=["auto", "opaque", "transparent"],
tooltip="Return image with or without background",
optional=True,
),
@ -395,6 +397,11 @@ class OpenAIGPTImage1(IO.ComfyNode):
tooltip="Optional mask for inpainting (white areas will be replaced)",
optional=True,
),
IO.Combo.Input(
"model",
options=["gpt-image-1", "gpt-image-1.5"],
optional=True,
),
],
outputs=[
IO.Image.Output(),
@ -410,32 +417,34 @@ class OpenAIGPTImage1(IO.ComfyNode):
@classmethod
async def execute(
cls,
prompt,
seed=0,
quality="low",
background="opaque",
image=None,
mask=None,
n=1,
size="1024x1024",
prompt: str,
seed: int = 0,
quality: str = "low",
background: str = "opaque",
image: Input.Image | None = None,
mask: Input.Image | None = None,
n: int = 1,
size: str = "1024x1024",
model: str = "gpt-image-1",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
model = "gpt-image-1"
path = "/proxy/openai/images/generations"
content_type = "application/json"
request_class = OpenAIImageGenerationRequest
files = []
if mask is not None and image is None:
raise ValueError("Cannot use a mask without an input image")
if model == "gpt-image-1":
price_extractor = calculate_tokens_price_image_1
elif model == "gpt-image-1.5":
price_extractor = calculate_tokens_price_image_1_5
else:
raise ValueError(f"Unknown model: {model}")
if image is not None:
path = "/proxy/openai/images/edits"
request_class = OpenAIImageEditRequest
content_type = "multipart/form-data"
files = []
batch_size = image.shape[0]
for i in range(batch_size):
single_image = image[i: i + 1]
scaled_image = downscale_image_tensor(single_image).squeeze()
scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
@ -448,44 +457,59 @@ class OpenAIGPTImage1(IO.ComfyNode):
else:
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
if mask is not None:
if image is None:
raise Exception("Cannot use a mask without an input image")
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
batch, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
if mask is not None:
if image.shape[0] != 1:
raise Exception("Cannot use a mask with multiple image")
if mask.shape[1:] != image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
_, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
# Build the operation
response = await sync_op(
cls,
ApiEndpoint(path=path, method="POST"),
response_model=OpenAIImageGenerationResponse,
data=request_class(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
),
files=files if files else None,
content_type=content_type,
)
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageEditRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
moderation="low",
),
content_type="multipart/form-data",
files=files,
price_extractor=price_extractor,
)
else:
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageGenerationRequest(
model=model,
prompt=prompt,
quality=quality,
background=background,
n=n,
seed=seed,
size=size,
moderation="low",
),
price_extractor=price_extractor,
)
return IO.NodeOutput(await validate_and_cast_response(response))

View File

@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = {
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
}
UPSCALER_VALUES_MAP = {
"FullHD (1080p)": 1920,
"4K (2160p)": 3840,
}
class TopazImageEnhance(IO.ComfyNode):
@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode):
IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())),
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"upscaler_creativity",
options=["low", "middle", "high"],
@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode):
target_frame_rate = src_frame_rate
filters = []
if upscaler_enabled:
target_width = UPSCALER_VALUES_MAP[upscaler_resolution]
target_height = UPSCALER_VALUES_MAP[upscaler_resolution]
if "1080p" in upscaler_resolution:
target_pixel_p = 1080
max_long_side = 1920
else:
target_pixel_p = 2160
max_long_side = 3840
ar = src_width / src_height
if src_width >= src_height:
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
target_height = target_pixel_p
target_width = int(target_height * ar)
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
if target_width > max_long_side:
target_width = max_long_side
target_height = int(target_width / ar)
else:
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
target_width = target_pixel_p
target_height = int(target_width / ar)
# Check if height exceeds standard bounds
if target_height > max_long_side:
target_height = max_long_side
target_width = int(target_height * ar)
if target_width % 2 != 0:
target_width += 1
if target_height % 2 != 0:
target_height += 1
filters.append(
topaz_api.VideoEnhancementFilter(
model=UPSCALER_MODELS_MAP[upscaler_model],

View File

@ -46,14 +46,14 @@ class Txt2ImageParametersField(BaseModel):
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
watermark: bool = Field(False)
class Image2ImageParametersField(BaseModel):
size: str | None = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(True)
watermark: bool = Field(False)
class Text2VideoParametersField(BaseModel):
@ -61,7 +61,7 @@ class Text2VideoParametersField(BaseModel):
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
watermark: bool = Field(False)
audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single")
@ -71,7 +71,7 @@ class Image2VideoParametersField(BaseModel):
seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True)
watermark: bool = Field(True)
watermark: bool = Field(False)
audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single")
@ -208,7 +208,7 @@ class WanTextToImageApi(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
@ -234,7 +234,7 @@ class WanTextToImageApi(IO.ComfyNode):
height: int = 1024,
seed: int = 0,
prompt_extend: bool = True,
watermark: bool = True,
watermark: bool = False,
):
initial_response = await sync_op(
cls,
@ -327,7 +327,7 @@ class WanImageToImageApi(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
@ -353,7 +353,7 @@ class WanImageToImageApi(IO.ComfyNode):
# width: int = 1024,
# height: int = 1024,
seed: int = 0,
watermark: bool = True,
watermark: bool = False,
):
n_images = get_number_of_images(image)
if n_images not in (1, 2):
@ -476,7 +476,7 @@ class WanTextToVideoApi(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
@ -512,7 +512,7 @@ class WanTextToVideoApi(IO.ComfyNode):
seed: int = 0,
generate_audio: bool = False,
prompt_extend: bool = True,
watermark: bool = True,
watermark: bool = False,
shot_type: str = "single",
):
if "480p" in size and model == "wan2.6-t2v":
@ -637,7 +637,7 @@ class WanImageToVideoApi(IO.ComfyNode):
),
IO.Boolean.Input(
"watermark",
default=True,
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
optional=True,
),
@ -674,7 +674,7 @@ class WanImageToVideoApi(IO.ComfyNode):
seed: int = 0,
generate_audio: bool = False,
prompt_extend: bool = True,
watermark: bool = True,
watermark: bool = False,
shot_type: str = "single",
):
if get_number_of_images(image) != 1:

View File

@ -129,7 +129,7 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
return img_byte_arr
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)

291
comfy_execution/jobs.py Normal file
View File

@ -0,0 +1,291 @@
"""
Job utilities for the /api/jobs endpoint.
Provides normalization and helper functions for job status tracking.
"""
from typing import Optional
from comfy_api.internal import prune_dict
class JobStatus:
"""Job status constants."""
PENDING = 'pending'
IN_PROGRESS = 'in_progress'
COMPLETED = 'completed'
FAILED = 'failed'
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
"""Extract create_time and workflow_id from extra_data.
Returns:
tuple: (create_time, workflow_id)
"""
create_time = extra_data.get('create_time')
extra_pnginfo = extra_data.get('extra_pnginfo', {})
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
return create_time, workflow_id
def is_previewable(media_type: str, item: dict) -> bool:
"""
Check if an output item is previewable.
Matches frontend logic in ComfyUI_frontend/src/stores/queueStore.ts
Maintains backwards compatibility with existing logic.
Priority:
1. media_type is 'images', 'video', or 'audio'
2. format field starts with 'video/' or 'audio/'
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
"""
if media_type in PREVIEWABLE_MEDIA_TYPES:
return True
# Check format field (MIME type).
# Maintains backwards compatibility with how custom node outputs are handled in the frontend.
fmt = item.get('format', '')
if fmt and (fmt.startswith('video/') or fmt.startswith('audio/')):
return True
# Check for 3D files by extension
filename = item.get('filename', '').lower()
if any(filename.endswith(ext) for ext in THREE_D_EXTENSIONS):
return True
return False
def normalize_queue_item(item: tuple, status: str) -> dict:
"""Convert queue item tuple to unified job dict.
Expects item with sensitive data already removed (5 elements).
"""
priority, prompt_id, _, extra_data, _ = item
create_time, workflow_id = _extract_job_metadata(extra_data)
return prune_dict({
'id': prompt_id,
'status': status,
'priority': priority,
'create_time': create_time,
'outputs_count': 0,
'workflow_id': workflow_id,
})
def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: bool = False) -> dict:
"""Convert history item dict to unified job dict.
History items have sensitive data already removed (prompt tuple has 5 elements).
"""
prompt_tuple = history_item['prompt']
priority, _, prompt, extra_data, _ = prompt_tuple
create_time, workflow_id = _extract_job_metadata(extra_data)
status_info = history_item.get('status', {})
status_str = status_info.get('status_str') if status_info else None
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.FAILED
else:
status = JobStatus.COMPLETED
outputs = history_item.get('outputs', {})
outputs_count, preview_output = get_outputs_summary(outputs)
execution_error = None
execution_start_time = None
execution_end_time = None
if status_info:
messages = status_info.get('messages', [])
for entry in messages:
if isinstance(entry, (list, tuple)) and len(entry) >= 2:
event_name, event_data = entry[0], entry[1]
if isinstance(event_data, dict):
if event_name == 'execution_start':
execution_start_time = event_data.get('timestamp')
elif event_name in ('execution_success', 'execution_error', 'execution_interrupted'):
execution_end_time = event_data.get('timestamp')
if event_name == 'execution_error':
execution_error = event_data
job = prune_dict({
'id': prompt_id,
'status': status,
'priority': priority,
'create_time': create_time,
'execution_start_time': execution_start_time,
'execution_end_time': execution_end_time,
'execution_error': execution_error,
'outputs_count': outputs_count,
'preview_output': preview_output,
'workflow_id': workflow_id,
})
if include_outputs:
job['outputs'] = outputs
job['execution_status'] = status_info
job['workflow'] = {
'prompt': prompt,
'extra_data': extra_data,
}
return job
def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
"""
Count outputs and find preview in a single pass.
Returns (outputs_count, preview_output).
Preview priority (matching frontend):
1. type="output" with previewable media
2. Any previewable media
"""
count = 0
preview_output = None
fallback_preview = None
for node_id, node_outputs in outputs.items():
if not isinstance(node_outputs, dict):
continue
for media_type, items in node_outputs.items():
# 'animated' is a boolean flag, not actual output items
if media_type == 'animated' or not isinstance(items, list):
continue
for item in items:
if not isinstance(item, dict):
continue
count += 1
if preview_output is None and is_previewable(media_type, item):
enriched = {
**item,
'nodeId': node_id,
'mediaType': media_type
}
if item.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched
return count, preview_output or fallback_preview
def apply_sorting(jobs: list[dict], sort_by: str, sort_order: str) -> list[dict]:
"""Sort jobs list by specified field and order."""
reverse = (sort_order == 'desc')
if sort_by == 'execution_duration':
def get_sort_key(job):
start = job.get('execution_start_time', 0)
end = job.get('execution_end_time', 0)
return end - start if end and start else 0
else:
def get_sort_key(job):
return job.get('create_time', 0)
return sorted(jobs, key=get_sort_key, reverse=reverse)
def get_job(prompt_id: str, running: list, queued: list, history: dict) -> Optional[dict]:
"""
Get a single job by prompt_id from history or queue.
Args:
prompt_id: The prompt ID to look up
running: List of currently running queue items
queued: List of pending queue items
history: Dict of history items keyed by prompt_id
Returns:
Job dict with full details, or None if not found
"""
if prompt_id in history:
return normalize_history_item(prompt_id, history[prompt_id], include_outputs=True)
for item in running:
if item[1] == prompt_id:
return normalize_queue_item(item, JobStatus.IN_PROGRESS)
for item in queued:
if item[1] == prompt_id:
return normalize_queue_item(item, JobStatus.PENDING)
return None
def get_all_jobs(
running: list,
queued: list,
history: dict,
status_filter: Optional[list[str]] = None,
workflow_id: Optional[str] = None,
sort_by: str = "created_at",
sort_order: str = "desc",
limit: Optional[int] = None,
offset: int = 0
) -> tuple[list[dict], int]:
"""
Get all jobs (running, pending, completed) with filtering and sorting.
Args:
running: List of currently running queue items
queued: List of pending queue items
history: Dict of history items keyed by prompt_id
status_filter: List of statuses to include (from JobStatus.ALL)
workflow_id: Filter by workflow ID
sort_by: Field to sort by ('created_at', 'execution_duration')
sort_order: 'asc' or 'desc'
limit: Maximum number of items to return
offset: Number of items to skip
Returns:
tuple: (jobs_list, total_count)
"""
jobs = []
if status_filter is None:
status_filter = JobStatus.ALL
if JobStatus.IN_PROGRESS in status_filter:
for item in running:
jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))
if JobStatus.PENDING in status_filter:
for item in queued:
jobs.append(normalize_queue_item(item, JobStatus.PENDING))
include_completed = JobStatus.COMPLETED in status_filter
include_failed = JobStatus.FAILED in status_filter
if include_completed or include_failed:
for prompt_id, history_item in history.items():
is_failed = history_item.get('status', {}).get('status_str') == 'error'
if (is_failed and include_failed) or (not is_failed and include_completed):
jobs.append(normalize_history_item(prompt_id, history_item))
if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
jobs = apply_sorting(jobs, sort_by, sort_order)
total_count = len(jobs)
if offset > 0:
jobs = jobs[offset:]
if limit is not None:
jobs = jobs[:limit]
return (jobs, total_count)

View File

@ -11,6 +11,7 @@ import comfy.utils
from comfy import node_helpers
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import re
class BasicScheduler(io.ComfyNode):
@ -762,8 +763,12 @@ class SamplerCustom(io.ComfyNode):
out = latent.copy()
out["samples"] = samples
if "x0" in x0_output:
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
if samples.is_nested:
latent_shapes = [x.shape for x in samples.unbind()]
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
out_denoised = latent.copy()
out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
out_denoised["samples"] = x0_out
else:
out_denoised = out
return io.NodeOutput(out, out_denoised)
@ -950,8 +955,12 @@ class SamplerCustomAdvanced(io.ComfyNode):
out = latent.copy()
out["samples"] = samples
if "x0" in x0_output:
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
if samples.is_nested:
latent_shapes = [x.shape for x in samples.unbind()]
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
out_denoised = latent.copy()
out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
out_denoised["samples"] = x0_out
else:
out_denoised = out
return io.NodeOutput(out, out_denoised)
@ -1007,6 +1016,25 @@ class AddNoise(io.ComfyNode):
add_noise = execute
class ManualSigmas(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ManualSigmas",
category="_for_testing/custom_sampling",
is_experimental=True,
inputs=[
io.String.Input("sigmas", default="1, 0.5", multiline=False)
],
outputs=[io.Sigmas.Output()]
)
@classmethod
def execute(cls, sigmas) -> io.NodeOutput:
sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas)
sigmas = [float(i) for i in sigmas]
sigmas = torch.FloatTensor(sigmas)
return io.NodeOutput(sigmas)
class CustomSamplersExtension(ComfyExtension):
@override
@ -1046,6 +1074,7 @@ class CustomSamplersExtension(ComfyExtension):
DisableNoise,
AddNoise,
SamplerCustomAdvanced,
ManualSigmas,
]

View File

@ -1126,6 +1126,99 @@ class MergeTextListsNode(TextProcessingNode):
# ========== Training Dataset Nodes ==========
class ResolutionBucket(io.ComfyNode):
"""Bucket latents and conditions by resolution for efficient batch training."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ResolutionBucket",
display_name="Resolution Bucket",
category="dataset",
is_experimental=True,
is_input_list=True,
inputs=[
io.Latent.Input(
"latents",
tooltip="List of latent dicts to bucket by resolution.",
),
io.Conditioning.Input(
"conditioning",
tooltip="List of conditioning lists (must match latents length).",
),
],
outputs=[
io.Latent.Output(
display_name="latents",
is_output_list=True,
tooltip="List of batched latent dicts, one per resolution bucket.",
),
io.Conditioning.Output(
display_name="conditioning",
is_output_list=True,
tooltip="List of condition lists, one per resolution bucket.",
),
],
)
@classmethod
def execute(cls, latents, conditioning):
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
# conditioning: list[list[cond]]
# Validate lengths match
if len(latents) != len(conditioning):
raise ValueError(
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
)
# Flatten latents and conditions to individual samples
flat_latents = [] # list of (C, H, W) tensors
flat_conditions = [] # list of condition lists
for latent_dict, cond in zip(latents, conditioning):
samples = latent_dict["samples"] # (B, C, H, W)
batch_size = samples.shape[0]
# cond is a list of conditions with length == batch_size
for i in range(batch_size):
flat_latents.append(samples[i]) # (C, H, W)
flat_conditions.append(cond[i]) # single condition
# Group by resolution (H, W)
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
for latent, cond in zip(flat_latents, flat_conditions):
# latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W)
h, w = latent.shape[-2], latent.shape[-1]
key = (h, w)
if key not in buckets:
buckets[key] = {"latents": [], "conditions": []}
buckets[key]["latents"].append(latent)
buckets[key]["conditions"].append(cond)
# Convert buckets to output format
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
for (h, w), bucket_data in buckets.items():
# Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
output_latents.append({"samples": stacked_latents})
# Conditions stay as list of condition lists
output_conditions.append(bucket_data["conditions"])
logging.info(
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
)
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
return io.NodeOutput(output_latents, output_conditions)
class MakeTrainingDataset(io.ComfyNode):
"""Encode images with VAE and texts with CLIP to create a training dataset."""
@ -1374,7 +1467,7 @@ class LoadTrainingDataset(io.ComfyNode):
shard_path = os.path.join(dataset_dir, shard_file)
with open(shard_path, "rb") as f:
shard_data = torch.load(f, weights_only=True)
shard_data = torch.load(f)
all_latents.extend(shard_data["latents"])
all_conditioning.extend(shard_data["conditioning"])
@ -1426,6 +1519,7 @@ class DatasetExtension(ComfyExtension):
MakeTrainingDataset,
SaveTrainingDataset,
LoadTrainingDataset,
ResolutionBucket,
]

View File

@ -1,13 +1,15 @@
import logging
import math
import torch
from comfy.nodes import base_nodes as nodes
from typing_extensions import override
import comfy.utils
from comfy.component_model.tensor_types import Latent
from comfy.nodes.package_typing import Seed, Seed64
from .nodes_post_processing import gaussian_kernel
from typing_extensions import override
from comfy.nodes import base_nodes as nodes
from comfy_api.latest import ComfyExtension, io
import logging
from .nodes_post_processing import gaussian_kernel
logger = logging.getLogger(__name__)
def reshape_latent_to(target_shape, latent, repeat_batch=True):
@ -216,6 +218,47 @@ class LatentCut(io.ComfyNode):
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return io.NodeOutput(samples_out)
class LatentCutToBatch(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LatentCutToBatch",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["t", "x", "y"]),
io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
if "x" in dim:
dim = s1.ndim - 1
elif "y" in dim:
dim = s1.ndim - 2
elif "t" in dim:
dim = s1.ndim - 3
if dim < 2:
return io.NodeOutput(samples)
s = s1.movedim(dim, 1)
if s.shape[1] < slice_size:
slice_size = s.shape[1]
elif s.shape[1] % slice_size != 0:
s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
new_shape = [-1, slice_size] + list(s.shape[2:])
samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
return io.NodeOutput(samples_out)
class LatentBatch(io.ComfyNode):
@classmethod
@ -500,6 +543,7 @@ class LatentExtension(ComfyExtension):
LatentInterpolate,
LatentConcat,
LatentCut,
LatentCutToBatch,
LatentBatch,
LatentBatchSeedBehavior,
LatentAddNoiseChannels,

View File

@ -354,7 +354,7 @@ class ZImageControlPatch:
if self.mask is None:
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
else:
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
if latent_image is None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))

View File

@ -228,6 +228,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
io.Image.Input("image"),
io.Combo.Input("upscale_method", options=cls.upscale_methods),
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
io.Int.Input("resolution_steps", default=1, min=1, max=256),
],
outputs=[
io.Image.Output(),
@ -235,15 +236,15 @@ class ImageScaleToTotalPixels(io.ComfyNode):
)
@classmethod
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
samples = image.movedim(-1, 1)
total = int(megapixels * 1024 * 1024)
total = megapixels * 1024 * 1024
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps
s = utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
s = s.movedim(1, -1)
return io.NodeOutput(s)

View File

@ -3,7 +3,9 @@ import comfy.utils
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import comfy.model_management
import torch
import nodes
class TextEncodeQwenImageEdit(io.ComfyNode):
@classmethod
@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
return io.NodeOutput(conditioning)
class EmptyQwenImageLayeredLatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyQwenImageLayeredLatentImage",
display_name="Empty Qwen Image Layered Latent",
category="latent/qwen",
inputs=[
io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
class QwenExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeQwenImageEdit,
TextEncodeQwenImageEditPlus,
EmptyQwenImageLayeredLatentImage,
]

View File

@ -11,6 +11,7 @@ from typing_extensions import override
import comfy.model_management
import comfy.samplers
import comfy.sampler_helpers
import comfy.sd
import comfy.utils
from comfy import node_helpers
@ -24,6 +25,69 @@ from .nodes_custom_sampler import Noise_RandomNoise, Guider_Basic
logger = logging.getLogger(__name__)
class TrainGuider(Guider_Basic):
"""
CFGGuider with modifications for training specific logic
"""
def outer_sample(
self,
noise,
latent_image,
sampler,
sigmas,
denoise_mask=None,
callback=None,
disable_pbar=False,
seed=None,
latent_shapes=None,
):
self.inner_model, self.conds, self.loaded_models = (
comfy.sampler_helpers.prepare_sampling(
self.model_patcher,
noise.shape,
self.conds,
self.model_options,
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
)
)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(
denoise_mask, noise.shape, device
)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
comfy.samplers.cast_to_load_options(
self.model_options, device=device, dtype=self.model_patcher.model_dtype()
)
try:
self.model_patcher.pre_run()
output = self.inner_sample(
noise,
latent_image,
device,
sampler,
sigmas,
denoise_mask,
callback,
disable_pbar,
seed,
latent_shapes=latent_shapes,
)
finally:
self.model_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
del self.loaded_models
return output
def make_batch_extra_option_dict(d, indicies, full_size=None):
new_dict = {}
for k, v in d.items():
@ -68,6 +132,7 @@ class TrainSampler(comfy.samplers.Sampler):
seed=0,
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
@ -78,6 +143,28 @@ class TrainSampler(comfy.samplers.Sampler):
self.seed = seed
self.training_dtype = training_dtype
self.real_dataset: list[torch.Tensor] | None = real_dataset
# Bucket mode data
self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi)
)
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
else:
self.bucket_offsets = None
self.bucket_weights = None
self.num_images = None
def _init_bucket_data(self, bucket_latents):
"""Initialize bucket offsets and weights for sampling."""
self.bucket_offsets = [0]
bucket_sizes = []
for lat in bucket_latents:
bucket_sizes.append(lat.shape[0])
self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
self.num_images = self.bucket_offsets[-1]
# Weights for sampling buckets proportional to their size
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
def fwd_bwd(
self,
@ -118,6 +205,109 @@ class TrainSampler(comfy.samplers.Sampler):
bwd_loss.backward()
return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
"""Generate random sigma values for a batch."""
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
for _ in range(batch_size)
]
return torch.tensor(batch_sigmas).to(device)
def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
"""Execute one training step in bucket mode."""
# Sample bucket (weighted by size), then sample batch from bucket
bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi)
bucket_size = bucket_latent.shape[0]
bucket_offset = self.bucket_offsets[bucket_idx]
# Sample indices from this bucket (use all if bucket_size < batch_size)
actual_batch_size = min(self.batch_size, bucket_size)
relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
absolute_indices = [bucket_offset + idx for idx in relative_indices]
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond, # Use flattened cond with absolute indices
absolute_indices,
extra_args,
self.num_images,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
"""Execute one training step in standard (non-bucket, non-multi-res) mode."""
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
"""Execute one training step in multi-resolution mode (real_dataset is set)."""
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
# todo: should this be "0" or scalar_tensor?
total_loss = torch.tensor(0.0)
for index in indicies:
single_latent = self.real_dataset[index].to(latent_image)
batch_noise = noisegen.generate_noise(
{"samples": single_latent}
).to(single_latent.device)
batch_sigmas = (
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
)
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
single_latent,
cond,
[index],
extra_args,
dataset_size,
bwd=False,
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
def sample(
self,
model_wrap,
@ -139,77 +329,24 @@ class TrainSampler(comfy.samplers.Sampler):
self.total_steps,
desc="Training LoRA",
smoothing=0.01,
disable=not current_execution_context().server.receive_all_progress_notifications
disable=not current_execution_context().server.receive_all_progress_notifications,
)
):
noisegen = Noise_RandomNoise(
self.seed + i * 1000
)
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
if self.real_dataset is None:
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
for _ in range(min(self.batch_size, dataset_size))
]
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
if self.bucket_latents is not None:
self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
elif self.real_dataset is None:
self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
else:
# todo: should this be "0" or scalar_tensor?
total_loss = torch.tensor(0.0)
for index in indicies:
single_latent = self.real_dataset[index].to(latent_image)
batch_noise = noisegen.generate_noise(
{"samples": single_latent}
).to(single_latent.device)
batch_sigmas = (
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
)
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
single_latent,
cond,
[index],
extra_args,
dataset_size,
bwd=False,
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
ui_pbar.update(1)
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
@ -287,6 +424,364 @@ def unpatch(m):
del m.org_forward
def _process_latents_bucket_mode(latents):
"""Process latents for bucket mode training.
Args:
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
Returns:
list of latent tensors
"""
bucket_latents = []
for latent_dict in latents:
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
return bucket_latents
def _process_latents_standard_mode(latents):
"""Process latents for standard (non-bucket) mode training.
Args:
latents: list of latent dicts or single latent dict
Returns:
Processed latents (tensor or list of tensors)
"""
if len(latents) == 1:
return latents[0]["samples"] # Single latent dict
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
return latent_list
def _process_conditioning(positive):
"""Process conditioning - either single list or list of lists.
Args:
positive: list of conditioning
Returns:
Flattened conditioning list
"""
if len(positive) == 1:
return positive[0] # Single conditioning list
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
return flat_positive
def _prepare_latents_and_count(latents, dtype, bucket_mode):
"""Convert latents to dtype and compute image counts.
Args:
latents: Latents (tensor, list of tensors, or bucket list)
dtype: Target dtype
bucket_mode: Whether bucket mode is enabled
Returns:
tuple: (processed_latents, num_images, multi_res)
"""
if bucket_mode:
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
latents = [t.to(dtype) for t in latents]
num_buckets = len(latents)
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
return latents, num_images, multi_res
# Non-bucket mode
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
multi_res = False
else:
logging.error(f"Invalid latents type: {type(latents)}")
num_images = 0
multi_res = False
return latents, num_images, multi_res
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
"""Validate conditioning count matches image count, expand if needed.
Args:
positive: Conditioning list
num_images: Number of images
bucket_mode: Whether bucket mode is enabled
Returns:
Validated/expanded conditioning list
Raises:
ValueError: If conditioning count doesn't match image count
"""
if bucket_mode:
return positive # Skip validation in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
return positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
return positive
def _load_existing_lora(existing_lora):
"""Load existing LoRA weights if provided.
Args:
existing_lora: LoRA filename or "[None]"
Returns:
tuple: (existing_weights dict, existing_steps int)
"""
if existing_lora == "[None]":
return {}, 0
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
existing_weights = {}
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
return existing_weights, existing_steps
def _create_weight_adapter(
module, module_name, existing_weights, algorithm, lora_dtype, rank
):
"""Create a weight adapter for a module with weight.
Args:
module: The module to create adapter for
module_name: Name of the module
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (train_adapter, lora_params dict)
"""
key = f"{module_name}.weight"
shape = module.weight.shape
lora_params = {}
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
# Try to load existing adapter
existing_adapter = None
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
module_name, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
if existing_adapter is None:
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
module.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_params[f"{module_name}.{name}"] = parameter
return train_adapter.train().requires_grad_(True), lora_params
else:
# 1D weight - use BiasDiff
diff = torch.nn.Parameter(
torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
)
diff_module = BiasDiff(diff).train().requires_grad_(True)
lora_params[f"{module_name}.diff"] = diff
return diff_module, lora_params
def _create_bias_adapter(module, module_name, lora_dtype):
"""Create a bias adapter for a module with bias.
Args:
module: The module with bias
module_name: Name of the module
lora_dtype: dtype for LoRA weights
Returns:
tuple: (bias_module, lora_params dict)
"""
bias = torch.nn.Parameter(
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias).train().requires_grad_(True)
lora_params = {f"{module_name}.diff_b": bias}
return bias_module, lora_params
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
"""Setup all LoRA adapters on the model.
Args:
mp: Model patcher
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (lora_sd dict, all_weight_adapters list)
"""
lora_sd = {}
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
adapter, params = _create_weight_adapter(
m, n, existing_weights, algorithm, lora_dtype, rank
)
lora_sd.update(params)
key = f"{n}.weight"
mp.add_weight_wrapper(key, adapter)
all_weight_adapters.append(adapter)
if hasattr(m, "bias") and m.bias is not None:
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
lora_sd.update(bias_params)
key = f"{n}.bias"
mp.add_weight_wrapper(key, bias_adapter)
all_weight_adapters.append(bias_adapter)
return lora_sd, all_weight_adapters
def _create_optimizer(optimizer_name, parameters, learning_rate):
"""Create optimizer based on name.
Args:
optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
parameters: Parameters to optimize
learning_rate: Learning rate
Returns:
Optimizer instance
"""
if optimizer_name == "Adam":
return torch.optim.Adam(parameters, lr=learning_rate)
elif optimizer_name == "AdamW":
return torch.optim.AdamW(parameters, lr=learning_rate)
elif optimizer_name == "SGD":
return torch.optim.SGD(parameters, lr=learning_rate)
elif optimizer_name == "RMSprop":
return torch.optim.RMSprop(parameters, lr=learning_rate)
def _create_loss_function(loss_function_name):
"""Create loss function based on name.
Args:
loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
Returns:
Loss function instance
"""
if loss_function_name == "MSE":
return torch.nn.MSELoss()
elif loss_function_name == "L1":
return torch.nn.L1Loss()
elif loss_function_name == "Huber":
return torch.nn.HuberLoss()
elif loss_function_name == "SmoothL1":
return torch.nn.SmoothL1Loss()
def _run_training_loop(
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
):
"""Execute the training loop.
Args:
guider: The guider object
train_sampler: The training sampler
latents: Latent tensors
num_images: Number of images
seed: Random seed
bucket_mode: Whether bucket mode is enabled
multi_res: Whether multi-resolution mode is enabled
"""
sigmas = torch.tensor(range(num_images))
noise = Noise_RandomNoise(seed)
if bucket_mode:
# Use first bucket's first latent as dummy for guider
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": dummy_latent}),
dummy_latent,
train_sampler,
sigmas,
seed=noise.seed,
)
elif multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
else:
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
class TrainLoraNode(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -389,6 +884,11 @@ class TrainLoraNode(io.ComfyNode):
default="[None]",
tooltip="The existing LoRA to append to. Set to None for new LoRA.",
),
io.Boolean.Input(
"bucket_mode",
default=False,
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
),
],
outputs=[
io.Model.Output(
@ -423,6 +923,7 @@ class TrainLoraNode(io.ComfyNode):
algorithm,
gradient_checkpointing,
existing_lora,
bucket_mode,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
@ -431,217 +932,125 @@ class TrainLoraNode(io.ComfyNode):
grad_accumulation_steps = grad_accumulation_steps[0]
learning_rate = learning_rate[0]
rank = rank[0]
optimizer = optimizer[0]
loss_function = loss_function[0]
optimizer_name = optimizer[0]
loss_function_name = loss_function[0]
seed = seed[0]
training_dtype = training_dtype[0]
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
# Handle latents - either single dict or list of dicts
if len(latents) == 1:
latents = latents[0]["samples"] # Single latent dict
# Process latents based on mode
if bucket_mode:
latents = _process_latents_bucket_mode(latents)
else:
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
latents = latent_list
latents = _process_latents_standard_mode(latents)
# Handle conditioning - either single list or list of lists
if len(positive) == 1:
positive = positive[0] # Single conditioning list
else:
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
positive = flat_positive
# Process conditioning
positive = _process_conditioning(positive)
# Setup model and dtype
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
# latents here can be list of different size latent or one large batch
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logger.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
multi_res = False
latents = latents.to(dtype)
num_images = latents.shape[0]
else:
raise ValueError(f"Invalid latents type: {type(latents)}")
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
)
logger.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False):
lora_sd = {}
generator = torch.Generator()
generator.manual_seed(seed)
# Setup models for training
mp.model.requires_grad_(False)
# Load existing LoRA weights if provided
existing_weights = {}
existing_steps = 0
if existing_lora != "[None]":
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
existing_weights, existing_steps = _load_existing_lora(existing_lora)
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
key = "{}.weight".format(n)
shape = m.weight.shape
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
n, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
else:
existing_adapter = None
adapter_cls = adapter_maps[algorithm]
# Setup LoRA adapters
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(
lora_dtype
)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
m.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_sd[f"{n}.{name}"] = parameter
# Create optimizer and loss function
optimizer = _create_optimizer(
optimizer_name, lora_sd.values(), learning_rate
)
criterion = _create_loss_function(loss_function_name)
mp.add_weight_wrapper(key, train_adapter)
all_weight_adapters.append(train_adapter)
else:
diff = torch.nn.Parameter(
torch.zeros(
m.weight.shape, dtype=lora_dtype, requires_grad=True
)
)
diff_module = BiasDiff(diff)
mp.add_weight_wrapper(key, BiasDiff(diff))
all_weight_adapters.append(diff_module)
lora_sd["{}.diff".format(n)] = diff
if hasattr(m, "bias") and m.bias is not None:
key = "{}.bias".format(n)
bias = torch.nn.Parameter(
torch.zeros(
m.bias.shape, dtype=lora_dtype, requires_grad=True
)
)
bias_module = BiasDiff(bias)
lora_sd["{}.diff_b".format(n)] = bias
mp.add_weight_wrapper(key, BiasDiff(bias))
all_weight_adapters.append(bias_module)
if optimizer == "Adam":
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
elif optimizer == "AdamW":
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
elif optimizer == "SGD":
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
elif optimizer == "RMSprop":
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
# Setup loss function based on selection
if loss_function == "MSE":
criterion = torch.nn.MSELoss()
elif loss_function == "L1":
criterion = torch.nn.L1Loss()
elif loss_function == "Huber":
criterion = torch.nn.HuberLoss()
elif loss_function == "SmoothL1":
criterion = torch.nn.SmoothL1Loss()
else:
criterion = None
# setup models
# Setup gradient checkpointing
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(
mp.model.diffusion_model
):
patch(m)
mp.model.requires_grad_(False)
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=True
)
# Setup sampler and guider like in test script
torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
[mp], memory_required=int(1e20), force_full_load=True
)
torch.cuda.empty_cache()
# Setup loss tracking
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
guider = Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
# Create sampler
if bucket_mode:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
)
else:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
# Training loop
# Setup guider
guider = TrainGuider(mp)
guider.set_conds(positive)
# Run training loop
try:
# Generate dummy sigmas and noise
sigmas = torch.tensor(range(num_images))
noise = Noise_RandomNoise(seed)
if multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
_run_training_loop(
guider,
train_sampler,
sigmas,
seed=noise.seed,
latents,
num_images,
seed,
bucket_mode,
multi_res,
)
finally:
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
# Finalize adapters
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
@ -651,7 +1060,7 @@ class TrainLoraNode(io.ComfyNode):
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):
class LoraModelLoader(io.ComfyNode): #
@classmethod
def define_schema(cls):
return io.Schema(

View File

@ -1,6 +1,6 @@
[project]
name = "comfyui"
version = "0.5.0"
version = "0.6.0"
description = "An installable version of ComfyUI"
readme = "README.md"
authors = [
@ -18,8 +18,8 @@ classifiers = [
]
dependencies = [
"comfyui-frontend-package>=1.33.13",
"comfyui-workflow-templates>=0.7.54",
"comfyui-frontend-package>=1.35.9",
"comfyui-workflow-templates>=0.7.64",
"comfyui-embedded-docs>=0.3.1",
"torch",
"torchvision",
@ -199,7 +199,7 @@ comfyui-manager = [
"chardet",
"pip",
# todo: bold move
# "comfyui_manager==4.0.3b4",
# "comfyui_manager==4.0.3b7",
]
[project.scripts]

View File

@ -843,3 +843,106 @@ class TestExecution:
# Check output
result_image = result.get_images(output)[0]
assert numpy.array(result_image).mean() == 0, "Image should be black"
# Jobs API tests
def test_jobs_api_job_structure(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test that job objects have required fields"""
self._create_history_item(client, builder)
jobs_response = client.get_jobs(status="completed", limit=1)
assert len(jobs_response["jobs"]) > 0, "Should have at least one job"
job = jobs_response["jobs"][0]
assert "id" in job, "Job should have id"
assert "status" in job, "Job should have status"
assert "create_time" in job, "Job should have create_time"
assert "outputs_count" in job, "Job should have outputs_count"
assert "preview_output" in job, "Job should have preview_output"
def test_jobs_api_preview_output_structure(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test that preview_output has correct structure"""
self._create_history_item(client, builder)
jobs_response = client.get_jobs(status="completed", limit=1)
job = jobs_response["jobs"][0]
if job["preview_output"] is not None:
preview = job["preview_output"]
assert "filename" in preview, "Preview should have filename"
assert "nodeId" in preview, "Preview should have nodeId"
assert "mediaType" in preview, "Preview should have mediaType"
def test_jobs_api_pagination(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API pagination"""
for _ in range(5):
self._create_history_item(client, builder)
first_page = client.get_jobs(limit=2, offset=0)
second_page = client.get_jobs(limit=2, offset=2)
assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs"
assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs"
first_ids = {j["id"] for j in first_page["jobs"]}
second_ids = {j["id"] for j in second_page["jobs"]}
assert first_ids.isdisjoint(second_ids), "Pages should have different jobs"
def test_jobs_api_sorting(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API sorting"""
for _ in range(3):
self._create_history_item(client, builder)
desc_jobs = client.get_jobs(sort_order="desc")
asc_jobs = client.get_jobs(sort_order="asc")
if len(desc_jobs["jobs"]) >= 2:
desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]]
asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]]
if len(desc_times) >= 2:
assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first"
if len(asc_times) >= 2:
assert asc_times == sorted(asc_times), "Asc should be oldest first"
def test_jobs_api_status_filter(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test jobs API status filtering"""
self._create_history_item(client, builder)
completed_jobs = client.get_jobs(status="completed")
assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history"
for job in completed_jobs["jobs"]:
assert job["status"] == "completed", "Should only return completed jobs"
# Pending jobs are transient - just verify filter doesn't error
pending_jobs = client.get_jobs(status="pending")
for job in pending_jobs["jobs"]:
assert job["status"] == "pending", "Should only return pending jobs"
def test_get_job_by_id(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test getting a single job by ID"""
result = self._create_history_item(client, builder)
prompt_id = result.get_prompt_id()
job = client.get_job(prompt_id)
assert job is not None, "Should find the job"
assert job["id"] == prompt_id, "Job ID should match"
assert "outputs" in job, "Single job should include outputs"
def test_get_job_not_found(
self, client: ComfyClient, builder: GraphBuilder
):
"""Test getting a non-existent job returns 404"""
job = client.get_job("nonexistent-job-id")
assert job is None, "Non-existent job should return None"

View File

@ -0,0 +1,361 @@
"""Unit tests for comfy_execution/jobs.py"""
from comfy_execution.jobs import (
JobStatus,
is_previewable,
normalize_queue_item,
normalize_history_item,
get_outputs_summary,
apply_sorting,
)
class TestJobStatus:
"""Test JobStatus constants."""
def test_status_values(self):
"""Status constants should have expected string values."""
assert JobStatus.PENDING == 'pending'
assert JobStatus.IN_PROGRESS == 'in_progress'
assert JobStatus.COMPLETED == 'completed'
assert JobStatus.FAILED == 'failed'
def test_all_contains_all_statuses(self):
"""ALL should contain all status values."""
assert JobStatus.PENDING in JobStatus.ALL
assert JobStatus.IN_PROGRESS in JobStatus.ALL
assert JobStatus.COMPLETED in JobStatus.ALL
assert JobStatus.FAILED in JobStatus.ALL
assert len(JobStatus.ALL) == 4
class TestIsPreviewable:
"""Unit tests for is_previewable()"""
def test_previewable_media_types(self):
"""Images, video, audio media types should be previewable."""
for media_type in ['images', 'video', 'audio']:
assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self):
"""Other media types should not be previewable."""
for media_type in ['latents', 'text', 'metadata', 'files']:
assert is_previewable(media_type, {}) is False
def test_3d_extensions_previewable(self):
"""3D file extensions should be previewable regardless of media_type."""
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
item = {'filename': f'model{ext}'}
assert is_previewable('files', item) is True
def test_3d_extensions_case_insensitive(self):
"""3D extension check should be case insensitive."""
item = {'filename': 'MODEL.GLB'}
assert is_previewable('files', item) is True
def test_video_format_previewable(self):
"""Items with video/ format should be previewable."""
item = {'format': 'video/mp4'}
assert is_previewable('files', item) is True
def test_audio_format_previewable(self):
"""Items with audio/ format should be previewable."""
item = {'format': 'audio/wav'}
assert is_previewable('files', item) is True
def test_other_format_not_previewable(self):
"""Items with other format should not be previewable."""
item = {'format': 'application/json'}
assert is_previewable('files', item) is False
class TestGetOutputsSummary:
"""Unit tests for get_outputs_summary()"""
def test_empty_outputs(self):
"""Empty outputs should return 0 count and None preview."""
count, preview = get_outputs_summary({})
assert count == 0
assert preview is None
def test_counts_across_multiple_nodes(self):
"""Outputs from multiple nodes should all be counted."""
outputs = {
'node1': {'images': [{'filename': 'a.png', 'type': 'output'}]},
'node2': {'images': [{'filename': 'b.png', 'type': 'output'}]},
'node3': {'images': [
{'filename': 'c.png', 'type': 'output'},
{'filename': 'd.png', 'type': 'output'}
]}
}
count, preview = get_outputs_summary(outputs)
assert count == 4
def test_skips_animated_key_and_non_list_values(self):
"""The 'animated' key and non-list values should be skipped."""
outputs = {
'node1': {
'images': [{'filename': 'test.png', 'type': 'output'}],
'animated': [True], # Should skip due to key name
'metadata': 'string', # Should skip due to non-list
'count': 42 # Should skip due to non-list
}
}
count, preview = get_outputs_summary(outputs)
assert count == 1
def test_preview_prefers_type_output(self):
"""Items with type='output' should be preferred for preview."""
outputs = {
'node1': {
'images': [
{'filename': 'temp.png', 'type': 'temp'},
{'filename': 'output.png', 'type': 'output'}
]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 2
assert preview['filename'] == 'output.png'
def test_preview_fallback_when_no_output_type(self):
"""If no type='output', should use first previewable."""
outputs = {
'node1': {
'images': [
{'filename': 'temp1.png', 'type': 'temp'},
{'filename': 'temp2.png', 'type': 'temp'}
]
}
}
count, preview = get_outputs_summary(outputs)
assert preview['filename'] == 'temp1.png'
def test_non_previewable_media_types_counted_but_no_preview(self):
"""Non-previewable media types should be counted but not used as preview."""
outputs = {
'node1': {
'latents': [
{'filename': 'latent1.safetensors'},
{'filename': 'latent2.safetensors'}
]
}
}
count, preview = get_outputs_summary(outputs)
assert count == 2
assert preview is None
def test_previewable_media_types(self):
"""Images, video, and audio media types should be previewable."""
for media_type in ['images', 'video', 'audio']:
outputs = {
'node1': {
media_type: [{'filename': 'test.file', 'type': 'output'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None, f"{media_type} should be previewable"
def test_3d_files_previewable(self):
"""3D file extensions should be previewable."""
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
outputs = {
'node1': {
'files': [{'filename': f'model{ext}', 'type': 'output'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None, f"3D file {ext} should be previewable"
def test_format_mime_type_previewable(self):
"""Files with video/ or audio/ format should be previewable."""
for fmt in ['video/x-custom', 'audio/x-custom']:
outputs = {
'node1': {
'files': [{'filename': 'file.custom', 'format': fmt, 'type': 'output'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview is not None, f"Format {fmt} should be previewable"
def test_preview_enriched_with_node_metadata(self):
"""Preview should include nodeId, mediaType, and original fields."""
outputs = {
'node123': {
'images': [{'filename': 'test.png', 'type': 'output', 'subfolder': 'outputs'}]
}
}
count, preview = get_outputs_summary(outputs)
assert preview['nodeId'] == 'node123'
assert preview['mediaType'] == 'images'
assert preview['subfolder'] == 'outputs'
class TestApplySorting:
"""Unit tests for apply_sorting()"""
def test_sort_by_create_time_desc(self):
"""Default sort by create_time descending."""
jobs = [
{'id': 'a', 'create_time': 100},
{'id': 'b', 'create_time': 300},
{'id': 'c', 'create_time': 200},
]
result = apply_sorting(jobs, 'created_at', 'desc')
assert [j['id'] for j in result] == ['b', 'c', 'a']
def test_sort_by_create_time_asc(self):
"""Sort by create_time ascending."""
jobs = [
{'id': 'a', 'create_time': 100},
{'id': 'b', 'create_time': 300},
{'id': 'c', 'create_time': 200},
]
result = apply_sorting(jobs, 'created_at', 'asc')
assert [j['id'] for j in result] == ['a', 'c', 'b']
def test_sort_by_execution_duration(self):
"""Sort by execution_duration should order by duration."""
jobs = [
{'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100}, # 5s
{'id': 'b', 'create_time': 300, 'execution_start_time': 300, 'execution_end_time': 1300}, # 1s
{'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200}, # 3s
]
result = apply_sorting(jobs, 'execution_duration', 'desc')
assert [j['id'] for j in result] == ['a', 'c', 'b']
def test_sort_with_none_values(self):
"""Jobs with None values should sort as 0."""
jobs = [
{'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100},
{'id': 'b', 'create_time': 300, 'execution_start_time': None, 'execution_end_time': None},
{'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200},
]
result = apply_sorting(jobs, 'execution_duration', 'asc')
assert result[0]['id'] == 'b' # None treated as 0, comes first
class TestNormalizeQueueItem:
"""Unit tests for normalize_queue_item()"""
def test_basic_normalization(self):
"""Queue item should be normalized to job dict."""
item = (
10, # priority/number
'prompt-123', # prompt_id
{'nodes': {}}, # prompt
{
'create_time': 1234567890,
'extra_pnginfo': {'workflow': {'id': 'workflow-abc'}}
}, # extra_data
['node1'], # outputs_to_execute
)
job = normalize_queue_item(item, JobStatus.PENDING)
assert job['id'] == 'prompt-123'
assert job['status'] == 'pending'
assert job['priority'] == 10
assert job['create_time'] == 1234567890
assert 'execution_start_time' not in job
assert 'execution_end_time' not in job
assert 'execution_error' not in job
assert 'preview_output' not in job
assert job['outputs_count'] == 0
assert job['workflow_id'] == 'workflow-abc'
class TestNormalizeHistoryItem:
"""Unit tests for normalize_history_item()"""
def test_completed_job(self):
"""Completed history item should have correct status and times from messages."""
history_item = {
'prompt': (
5, # priority
'prompt-456',
{'nodes': {}},
{
'create_time': 1234567890000,
'extra_pnginfo': {'workflow': {'id': 'workflow-xyz'}}
},
['node1'],
),
'status': {
'status_str': 'success',
'completed': True,
'messages': [
('execution_start', {'prompt_id': 'prompt-456', 'timestamp': 1234567890500}),
('execution_success', {'prompt_id': 'prompt-456', 'timestamp': 1234567893000}),
]
},
'outputs': {},
}
job = normalize_history_item('prompt-456', history_item)
assert job['id'] == 'prompt-456'
assert job['status'] == 'completed'
assert job['priority'] == 5
assert job['execution_start_time'] == 1234567890500
assert job['execution_end_time'] == 1234567893000
assert job['workflow_id'] == 'workflow-xyz'
def test_failed_job(self):
"""Failed history item should have failed status and error from messages."""
history_item = {
'prompt': (
5,
'prompt-789',
{'nodes': {}},
{'create_time': 1234567890000},
['node1'],
),
'status': {
'status_str': 'error',
'completed': False,
'messages': [
('execution_start', {'prompt_id': 'prompt-789', 'timestamp': 1234567890500}),
('execution_error', {
'prompt_id': 'prompt-789',
'node_id': '5',
'node_type': 'KSampler',
'exception_message': 'CUDA out of memory',
'exception_type': 'RuntimeError',
'traceback': ['Traceback...', 'RuntimeError: CUDA out of memory'],
'timestamp': 1234567891000,
})
]
},
'outputs': {},
}
job = normalize_history_item('prompt-789', history_item)
assert job['status'] == 'failed'
assert job['execution_start_time'] == 1234567890500
assert job['execution_end_time'] == 1234567891000
assert job['execution_error']['node_id'] == '5'
assert job['execution_error']['node_type'] == 'KSampler'
assert job['execution_error']['exception_message'] == 'CUDA out of memory'
def test_include_outputs(self):
"""When include_outputs=True, should include full output data."""
history_item = {
'prompt': (
5,
'prompt-123',
{'nodes': {'1': {}}},
{'create_time': 1234567890, 'client_id': 'abc'},
['node1'],
),
'status': {'status_str': 'success', 'completed': True, 'messages': []},
'outputs': {'node1': {'images': [{'filename': 'test.png'}]}},
}
job = normalize_history_item('prompt-123', history_item, include_outputs=True)
assert 'outputs' in job
assert 'workflow' in job
assert 'execution_status' in job
assert job['outputs'] == {'node1': {'images': [{'filename': 'test.png'}]}}
assert job['workflow'] == {
'prompt': {'nodes': {'1': {}}},
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
}