diff --git a/README.md b/README.md index bae955b1b..6d09758c0 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - Releases a new stable version (e.g., v0.7.0) roughly every week. + - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release. + - Minor versions will be used for releases off the master branch. + - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense. - Commits outside of the stable release tags may be very unstable and break many custom nodes. - Serves as the foundation for the desktop release @@ -209,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 +torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old. + ### Instructions: Git clone this repo. diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2979b3ca1..1e0f86026 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -143,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC): # if multiple conds, split based on primary region if self.split_conds_to_windows and len(cond_in) > 1: region = window.get_region_index(len(cond_in)) - logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}") + logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}") cond_in = [cond_in[region]] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 753c66afa..0949dee44 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.): def default_noise_sampler(x, seed=None): if seed is not None: + if x.device == torch.device("cpu"): + seed += 1 + generator = torch.Generator(device=x.device) generator.manual_seed(seed) else: @@ -1618,6 +1621,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non x = x + sde_noise * sigmas[i + 1] * s_noise return x +@torch.no_grad() +def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"): + """Deterministic exponential Heun second order method in data prediction (x0) and logSNR time.""" + return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type) + + +@torch.no_grad() +def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"): + """Stochastic exponential Heun second order method in data prediction (x0) and logSNR time.""" + return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type) + @torch.no_grad() def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): @@ -1765,7 +1779,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Predictor if sigmas[i + 1] == 0: # Denoising step - x = denoised + x_pred = denoised else: tau_t = tau_func(sigmas[i + 1]) curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1] @@ -1786,7 +1800,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F if tau_t > 0 and s_noise > 0: noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise x_pred = x_pred + noise - return x + return x_pred @torch.no_grad() diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 5628e2ba3..afbab2ac7 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -491,7 +491,8 @@ class NextDiT(nn.Module): for layer_id in range(n_layers) ] ) - self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + # This norm final is in the lumina 2.0 code but isn't actually used for anything. + # self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) if self.pad_tokens_multiple is not None: @@ -625,7 +626,7 @@ class NextDiT(nn.Module): if pooled is not None: pooled = self.clip_text_pooled_proj(pooled) else: - pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 902af30ed..00c597535 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis): class QwenTimestepProjEmbeddings(nn.Module): - def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): + def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) self.timestep_embedder = TimestepEmbedding( @@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module): operations=operations ) - def forward(self, timestep, hidden_states): + self.use_additional_t_cond = use_additional_t_cond + if self.use_additional_t_cond: + self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype) + + def forward(self, timestep, hidden_states, addition_t_cond=None): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + + if self.use_additional_t_cond: + if addition_t_cond is None: + addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long) + timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype) + return timesteps_emb @@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module): num_attention_heads: int = 24, joint_attention_dim: int = 3584, pooled_projection_dim: int = 768, - guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), default_ref_method="index", image_model=None, final_layer=True, + use_additional_t_cond=False, dtype=None, device=None, operations=None, @@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module): self.time_text_embed = QwenTimestepProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim, + use_additional_t_cond=use_additional_t_cond, dtype=dtype, device=device, operations=operations @@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module): patch_size = self.patch_size hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) orig_shape = hidden_states.shape - hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) - hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) - hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6) + hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + t_len = t h_len = ((h + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size) h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device) - img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) - return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device) - def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): + if t_len > 1: + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1) + else: + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index + + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2) + return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape + + def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) + ).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs) def _forward( self, @@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module): timesteps, context, attention_mask=None, - guidance: torch.Tensor = None, ref_latents=None, + additional_t_cond=None, transformer_options={}, control=None, **kwargs @@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module): index = 0 ref_method = kwargs.get("ref_latents_method", self.default_ref_method) index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") + negative_ref_method = ref_method == "negative_index" timestep_zero = ref_method == "index_timestep_zero" for ref in ref_latents: if index_ref_method: index += 1 h_offset = 0 w_offset = 0 + elif negative_ref_method: + index -= 1 + h_offset = 0 + w_offset = 0 else: index = 1 h_offset = 0 @@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - if guidance is not None: - guidance = guidance * 1000 - - temb = ( - self.time_text_embed(timestep, hidden_states) - if guidance is None - else self.time_text_embed(timestep, guidance, hidden_states) - ) + temb = self.time_text_embed(timestep, hidden_states, additional_t_cond) patches_replace = transformer_options.get("patches_replace", {}) patches = transformer_options.get("patches", {}) @@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) - hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) + hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6) return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index ccbb25822..08315f1a8 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -227,6 +227,7 @@ class Encoder3d(nn.Module): def __init__(self, dim=128, z_dim=4, + input_channels=3, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], @@ -245,7 +246,7 @@ class Encoder3d(nn.Module): scale = 1.0 # init block - self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1) # downsample blocks downsamples = [] @@ -331,6 +332,7 @@ class Decoder3d(nn.Module): def __init__(self, dim=128, z_dim=4, + output_channels=3, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], @@ -378,7 +380,7 @@ class Decoder3d(nn.Module): # output blocks self.head = nn.Sequential( RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, 3, 3, padding=1)) + CausalConv3d(out_dim, output_channels, 3, padding=1)) def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 @@ -449,6 +451,7 @@ class WanVAE(nn.Module): num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], + image_channels=3, dropout=0.0): super().__init__() self.dim = dim @@ -460,11 +463,11 @@ class WanVAE(nn.Module): self.temperal_upsample = temperal_downsample[::-1] # modules - self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def encode(self, x): diff --git a/comfy/model_base.py b/comfy/model_base.py index 53f953710..ef13523cb 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1122,7 +1122,7 @@ class Lumina2(BaseModel): if 'num_tokens' not in out: out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) - clip_text_pooled = kwargs["pooled_output"] # Newbie + clip_text_pooled = kwargs.get("pooled_output", None) # NewBie if clip_text_pooled is not None: out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 886409d47..8680d9c54 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) - if ctd_weight is not None: + if ctd_weight is not None: # NewBie dit_config["clip_text_dim"] = ctd_weight.shape[0] + # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 dit_config["n_kv_heads"] = 30 @@ -642,6 +643,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511 dit_config["default_ref_method"] = "index_timestep_zero" + if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered + dit_config["use_additional_t_cond"] = True + dit_config["default_ref_method"] = "negative_index" return dit_config if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 diff --git a/comfy/model_management.py b/comfy/model_management.py index 40717b1e4..87baedd73 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,7 @@ import importlib import platform import weakref import gc +import os class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -333,13 +334,15 @@ except: SUPPORT_FP8_OPS = args.supports_fp8_compute AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] +AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): - torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD - logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': + torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD + logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) @@ -1016,8 +1019,8 @@ NUM_STREAMS = 0 if args.async_offload is not None: NUM_STREAMS = args.async_offload else: - # Enable by default on Nvidia - if is_nvidia(): + # Enable by default on Nvidia and AMD + if is_nvidia() or is_amd(): NUM_STREAMS = 2 if args.disable_async_offload: @@ -1123,6 +1126,16 @@ if not args.disable_pinned_memory: PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +def discard_cuda_async_error(): + try: + a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + _ = a + b + torch.cuda.synchronize() + except torch.AcceleratorError: + #Dump it! We already know about it from the synchronous return + pass + def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: @@ -1155,6 +1168,9 @@ def pin_memory(tensor): PINNED_MEMORY[ptr] = size TOTAL_PINNED_MEMORY += size return True + else: + logging.warning("Pin error.") + discard_cuda_async_error() return False @@ -1183,6 +1199,9 @@ def unpin_memory(tensor): if len(PINNED_MEMORY) == 0: TOTAL_PINNED_MEMORY = 0 return True + else: + logging.warning("Unpin error.") + discard_cuda_async_error() return False diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index e46971afb..9134e6d71 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -122,20 +122,20 @@ def estimate_memory(model, noise_shape, conds): minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min) return memory_required, minimum_memory_required -def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _prepare_sampling, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) ) - return executor.execute(model, noise_shape, conds, model_options=model_options) + return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load) -def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): +def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) - comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory) + comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load) real_model = model.model return real_model, conds, models diff --git a/comfy/samplers.py b/comfy/samplers.py index fa4640842..1989ef107 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -720,7 +720,7 @@ class Sampler: sigma = float(sigmas[0]) return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma -KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", +KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", @@ -984,9 +984,6 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - if denoise_mask is not None: - denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) - noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) @@ -1013,6 +1010,24 @@ class CFGGuider: else: latent_shapes = [latent_image.shape] + if denoise_mask is not None: + if denoise_mask.is_nested: + denoise_masks = denoise_mask.unbind() + denoise_masks = denoise_masks[:len(latent_shapes)] + else: + denoise_masks = [denoise_mask] + + for i in range(len(denoise_masks), len(latent_shapes)): + denoise_masks.append(torch.ones(latent_shapes[i])) + + for i in range(len(denoise_masks)): + denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device) + + if len(denoise_masks) > 1: + denoise_mask, _ = comfy.utils.pack_latents(denoise_masks) + else: + denoise_mask = denoise_masks[0] + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) diff --git a/comfy/sd.py b/comfy/sd.py index 5f89d2c82..69ec40756 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -56,6 +56,8 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.newbie import comfy.model_patcher import comfy.lora @@ -325,6 +327,7 @@ class VAE: self.latent_channels = 4 self.latent_dim = 2 self.output_channels = 3 + self.pad_channel_value = None self.process_input = lambda image: image * 2.0 - 1.0 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] @@ -450,6 +453,7 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype) self.latent_channels = 64 self.output_channels = 2 + self.pad_channel_value = "replicate" self.upscale_ratio = 2048 self.downscale_ratio = 2048 self.latent_dim = 1 @@ -562,7 +566,9 @@ class VAE: self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = 16 - ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + self.output_channels = sd["encoder.conv1.weight"].shape[1] + self.pad_channel_value = 1.0 + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) @@ -598,6 +604,7 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) self.latent_channels = 8 self.output_channels = 2 + self.pad_channel_value = "replicate" self.upscale_ratio = 4096 self.downscale_ratio = 4096 self.latent_dim = 2 @@ -706,17 +713,28 @@ class VAE: raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") def vae_encode_crop_pixels(self, pixels): - if not self.crop_input: - return pixels + if self.crop_input: + downscale_ratio = self.spacial_compression_encode() - downscale_ratio = self.spacial_compression_encode() + dims = pixels.shape[1:-1] + for d in range(len(dims)): + x = (dims[d] // downscale_ratio) * downscale_ratio + x_offset = (dims[d] % downscale_ratio) // 2 + if x != dims[d]: + pixels = pixels.narrow(d + 1, x_offset, x) - dims = pixels.shape[1:-1] - for d in range(len(dims)): - x = (dims[d] // downscale_ratio) * downscale_ratio - x_offset = (dims[d] % downscale_ratio) // 2 - if x != dims[d]: - pixels = pixels.narrow(d + 1, x_offset, x) + if pixels.shape[-1] > self.output_channels: + pixels = pixels[..., :self.output_channels] + elif pixels.shape[-1] < self.output_channels: + if self.pad_channel_value is not None: + if isinstance(self.pad_channel_value, str): + mode = self.pad_channel_value + value = None + else: + mode = "constant" + value = self.pad_channel_value + + pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value) return pixels def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): @@ -1008,6 +1026,7 @@ class CLIPType(Enum): OVIS = 21 KANDINSKY5 = 22 KANDINSKY5_IMAGE = 23 + NEWBIE = 24 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1038,6 +1057,7 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 QWEN3_2B = 17 + JINA_CLIP_2 = 18 def detect_te_model(sd): @@ -1047,6 +1067,8 @@ def detect_te_model(sd): return TEModel.CLIP_H if "text_model.encoder.layers.0.mlp.fc1.weight" in sd: return TEModel.CLIP_L + if "model.encoder.layers.0.mixer.Wqkv.weight" in sd: + return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] if weight.shape[-1] == 4096: @@ -1207,6 +1229,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN3_2B: clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer + elif te_model == TEModel.JINA_CLIP_2: + clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper + clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper else: # clip_l if clip_type == CLIPType.SD3: @@ -1262,6 +1287,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.NEWBIE: + clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer + if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]: + clip_data_gemma = clip_data[0] + clip_data_jina = clip_data[1] + else: + clip_data_gemma = clip_data[1] + clip_data_jina = clip_data[0] + tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) + tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 962948dae..c512ca5d0 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) @@ -513,6 +513,8 @@ class SDTokenizer: self.embedding_size = embedding_size self.embedding_key = embedding_key + self.disable_weights = disable_weights + def _try_get_embedding(self, embedding_name:str): ''' Takes a potential embedding name and tries to retrieve it. @@ -547,7 +549,7 @@ class SDTokenizer: min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) text = escape_important(text) - if kwargs.get("disable_weights", False): + if kwargs.get("disable_weights", self.disable_weights): parsed_weights = [(text, 1.0)] else: parsed_weights = token_weights(text, 1.0) diff --git a/comfy/text_encoders/jina_clip_2.py b/comfy/text_encoders/jina_clip_2.py new file mode 100644 index 000000000..0cffb6d16 --- /dev/null +++ b/comfy/text_encoders/jina_clip_2.py @@ -0,0 +1,219 @@ +# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation: +# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py +# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py + +from dataclasses import dataclass + +import torch +from torch import nn as nn +from torch.nn import functional as F + +import comfy.model_management +import comfy.ops +from comfy import sd1_clip +from .spiece_tokenizer import SPieceTokenizer + +class JinaClip2Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192 + super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2") + +# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json +@dataclass +class XLMRobertaConfig: + vocab_size: int = 250002 + type_vocab_size: int = 1 + hidden_size: int = 1024 + num_hidden_layers: int = 24 + num_attention_heads: int = 16 + rotary_emb_base: float = 20000.0 + intermediate_size: int = 4096 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + layer_norm_eps: float = 1e-05 + bos_token_id: int = 0 + eos_token_id: int = 2 + pad_token_id: int = 1 + +class XLMRobertaEmbeddings(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype) + self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype) + + def forward(self, input_ids=None, embeddings=None): + if input_ids is not None and embeddings is None: + embeddings = self.word_embeddings(input_ids) + + if embeddings is not None: + token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + return embeddings + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, base, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype: + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + emb = torch.cat((freqs, freqs), dim=-1) + self._cos_cached = emb.cos().to(dtype) + self._sin_cached = emb.sin().to(dtype) + + def forward(self, q, k): + batch, seqlen, heads, head_dim = q.shape + self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype) + + cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim) + sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim) + + def rotate_half(x): + size = x.shape[-1] // 2 + x1, x2 = x[..., :size], x[..., size:] + return torch.cat((-x2, x1), dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class MHA(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = embed_dim // config.num_attention_heads + + self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device) + self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype) + self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, x, mask=None, optimized_attention=None): + qkv = self.Wqkv(x) + batch_size, seq_len, _ = qkv.shape + qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + + q, k = self.rotary_emb(q, k) + + # NHD -> HND + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True) + return self.out_proj(out) + +class MLP(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype) + self.activation = F.gelu + self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + +class Block(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.mixer = MHA(config, device=device, dtype=dtype, ops=ops) + self.dropout1 = nn.Dropout(config.hidden_dropout_prob) + self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) + self.dropout2 = nn.Dropout(config.hidden_dropout_prob) + self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + + def forward(self, hidden_states, mask=None, optimized_attention=None): + mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention) + hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states) + mlp_out = self.mlp(hidden_states) + hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states) + return hidden_states + +class XLMRobertaEncoder(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask=None): + optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True) + for layer in self.layers: + hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention) + return hidden_states + +class XLMRobertaModel_(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops) + self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.emb_drop = nn.Dropout(config.hidden_dropout_prob) + self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops) + + def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + x = self.embeddings(input_ids=input_ids, embeddings=embeds) + x = self.emb_ln(x) + x = self.emb_drop(x) + + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1])) + mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) + + sequence_output = self.encoder(x, attention_mask=mask) + + # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py + pooled_output = None + if attention_mask is None: + pooled_output = sequence_output.mean(dim=1) + else: + attention_mask = attention_mask.to(sequence_output.dtype) + pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True) + + # Intermediate output is not yet implemented, use None for placeholder + return sequence_output, None, pooled_output + +class XLMRobertaModel(nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.config = XLMRobertaConfig(**config_dict) + self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations) + self.num_layers = self.config.num_hidden_layers + + def get_input_embeddings(self): + return self.model.embeddings.word_embeddings + + def set_input_embeddings(self, embeddings): + self.model.embeddings.word_embeddings = embeddings + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + +class JinaClip2TextModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) + +class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 0d07ac8c6..ed29e014d 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -3,7 +3,6 @@ import torch.nn as nn from dataclasses import dataclass from typing import Optional, Any import math -import logging from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management @@ -177,7 +176,7 @@ class Gemma3_4B_Config: num_key_value_heads: int = 4 max_position_embeddings: int = 131072 rms_norm_eps: float = 1e-6 - rope_theta = [10000.0, 1000000.0] + rope_theta = [1000000.0, 10000.0] transformer_type: str = "gemma3" head_dim = 256 rms_norm_add = True @@ -186,8 +185,8 @@ class Gemma3_4B_Config: rope_dims = None q_norm = "gemma3" k_norm = "gemma3" - sliding_attention = [False, False, False, False, False, 1024] - rope_scale = [1.0, 8.0] + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] final_norm: bool = True class RMSNorm(nn.Module): @@ -370,7 +369,7 @@ class TransformerBlockGemma2(nn.Module): self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) - if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens) + if config.sliding_attention is not None: self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] else: self.sliding_attention = False @@ -387,7 +386,12 @@ class TransformerBlockGemma2(nn.Module): if self.transformer_type == 'gemma3': if self.sliding_attention: if x.shape[1] > self.sliding_attention: - logging.warning("Warning: sliding attention not implemented, results may be incorrect") + sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype) + sliding_mask.tril_(diagonal=-self.sliding_attention) + if attention_mask is not None: + attention_mask = attention_mask + sliding_mask + else: + attention_mask = sliding_mask freqs_cis = freqs_cis[1] else: freqs_cis = freqs_cis[0] diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index 7a6cfdab2..b29a7cc87 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer): class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} @@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel): class Gemma3_4BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class LuminaModel(sd1_clip.SD1ClipModel): diff --git a/comfy/text_encoders/newbie.py b/comfy/text_encoders/newbie.py new file mode 100644 index 000000000..db2324576 --- /dev/null +++ b/comfy/text_encoders/newbie.py @@ -0,0 +1,62 @@ +import torch + +import comfy.model_management +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.lumina2 + +class NewBieTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]}) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]}) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs) + out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + raise NotImplementedError + + def state_dict(self): + return {} + +class NewBieTEModel(torch.nn.Module): + def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}): + super().__init__() + dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device) + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options) + self.dtypes = {dtype, dtype_gemma} + + def set_clip_options(self, options): + self.gemma.set_clip_options(options) + self.jina.set_clip_options(options) + + def reset_clip_options(self): + self.gemma.reset_clip_options() + self.jina.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_gemma = token_weight_pairs["gemma"] + token_weight_pairs_jina = token_weight_pairs["jina"] + + gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma) + jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina) + + return gemma_out, jina_pooled, gemma_extra + + def load_sd(self, sd): + if "model.layers.0.self_attn.q_norm.weight" in sd: + return self.gemma.load_sd(sd) + else: + return self.jina.load_sd(sd) + +def te(dtype_llama=None, llama_quantization_metadata=None): + class NewBieTEModel_(NewBieTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return NewBieTEModel_ diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 2b634d172..ba0b95498 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -28,9 +28,8 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr prune_dict, shallow_clone_class) from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL +from ._util import MESH, VOXEL, SVG as _SVG -# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference class FolderType(str, Enum): input = "input" @@ -656,7 +655,7 @@ class Video(ComfyTypeIO): @comfytype(io_type="SVG") class SVG(ComfyTypeIO): - Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 + Type = _SVG @comfytype(io_type="LORA_MODEL") class LoraModel(ComfyTypeIO): @@ -1556,12 +1555,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]: + def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]: """Creates clone of real node class to prevent monkey-patching.""" c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"]) + type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None) return type_clone @final diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index fc5431dda..6313eb01b 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,6 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents from .geometry_types import VOXEL, MESH +from .image_types import SVG __all__ = [ # Utility Types @@ -8,4 +9,5 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "SVG", ] diff --git a/comfy_api/latest/_util/image_types.py b/comfy_api/latest/_util/image_types.py new file mode 100644 index 000000000..f031ed426 --- /dev/null +++ b/comfy_api/latest/_util/image_types.py @@ -0,0 +1,18 @@ +from io import BytesIO + + +class SVG: + """Stores SVG representations via a list of BytesIO objects.""" + + def __init__(self, data: list[BytesIO]): + self.data = data + + def combine(self, other: 'SVG') -> 'SVG': + return SVG(self.data + other.data) + + @staticmethod + def combine_all(svgs: list['SVG']) -> 'SVG': + all_svgs_list: list[BytesIO] = [] + for svg_item in svgs: + all_svgs_list.extend(svg_item.data) + return SVG(all_svgs_list) diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py index 77cd76f9b..b8c2f618b 100644 --- a/comfy_api_nodes/apis/bytedance_api.py +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -10,7 +10,7 @@ class Text2ImageTaskCreationRequest(BaseModel): size: str | None = Field(None) seed: int | None = Field(0, ge=0, le=2147483647) guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(True) + watermark: bool | None = Field(False) class Image2ImageTaskCreationRequest(BaseModel): @@ -21,7 +21,7 @@ class Image2ImageTaskCreationRequest(BaseModel): size: str | None = Field("adaptive") seed: int | None = Field(..., ge=0, le=2147483647) guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(True) + watermark: bool | None = Field(False) class Seedream4Options(BaseModel): @@ -37,7 +37,7 @@ class Seedream4TaskCreationRequest(BaseModel): seed: int = Field(..., ge=0, le=2147483647) sequential_image_generation: str = Field("disabled") sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) - watermark: bool = Field(True) + watermark: bool = Field(False) class ImageTaskCreationResponse(BaseModel): diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index f8edc38c9..d81337dae 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel): systemInstruction: GeminiSystemInstructionContent | None = Field(None) tools: list[GeminiTool] | None = Field(None) videoMetadata: GeminiVideoMetadata | None = Field(None) + uploadImagesToStorage: bool = Field(True) class GeminiGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py index 80a758466..bf54ede3e 100644 --- a/comfy_api_nodes/apis/kling_api.py +++ b/comfy_api_nodes/apis/kling_api.py @@ -102,3 +102,12 @@ class ImageToVideoWithAudioRequest(BaseModel): prompt: str = Field(...) mode: str = Field("pro") sound: str = Field(..., description="'on' or 'off'") + + +class MotionControlRequest(BaseModel): + prompt: str = Field(...) + image_url: str = Field(...) + video_url: str = Field(...) + keep_original_sound: str = Field(...) + character_orientation: str = Field(...) + mode: str = Field(..., description="'pro' or 'std'") diff --git a/comfy_api_nodes/apis/openai_api.py b/comfy_api_nodes/apis/openai_api.py new file mode 100644 index 000000000..ae5bb2673 --- /dev/null +++ b/comfy_api_nodes/apis/openai_api.py @@ -0,0 +1,52 @@ +from pydantic import BaseModel, Field + + +class Datum2(BaseModel): + b64_json: str | None = Field(None, description="Base64 encoded image data") + revised_prompt: str | None = Field(None, description="Revised prompt") + url: str | None = Field(None, description="URL of the image") + + +class InputTokensDetails(BaseModel): + image_tokens: int | None = None + text_tokens: int | None = None + + +class Usage(BaseModel): + input_tokens: int | None = None + input_tokens_details: InputTokensDetails | None = None + output_tokens: int | None = None + total_tokens: int | None = None + + +class OpenAIImageGenerationResponse(BaseModel): + data: list[Datum2] | None = None + usage: Usage | None = None + + +class OpenAIImageEditRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str = Field(...) + moderation: str | None = Field(None) + n: int | None = Field(None, description="The number of images to generate") + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + size: str | None = Field(None, description="Size of the output image") + + +class OpenAIImageGenerationRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str | None = Field(None) + moderation: str | None = Field(None) + n: int | None = Field( + None, + description="The number of images to generate.", + ) + output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") + output_format: str | None = Field(None) + prompt: str = Field(...) + quality: str | None = Field(None, description="The quality of the generated image") + size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") + style: str | None = Field(None, description="Style of the image (only for dall-e-3)") diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 8826dea0c..ce077d6b3 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,10 +1,8 @@ -from inspect import cleandoc - import torch from pydantic import BaseModel from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bfl_api import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, @@ -28,7 +26,7 @@ from comfy_api_nodes.util import ( ) -def convert_mask_to_image(mask: torch.Tensor): +def convert_mask_to_image(mask: Input.Image): """ Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. """ @@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor): class FluxProUltraImageNode(IO.ComfyNode): - """ - Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode): node_id="FluxProUltraImageNode", display_name="Flux 1.1 [pro] Ultra Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.", inputs=[ IO.String.Input( "prompt", @@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode): prompt_upsampling: bool = False, raw: bool = False, seed: int = 0, - image_prompt: torch.Tensor | None = None, + image_prompt: Input.Image | None = None, image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: @@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode): class FluxKontextProImageNode(IO.ComfyNode): - """ - Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode): node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.", inputs=[ IO.String.Input( "prompt", @@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode): aspect_ratio: str, guidance: float, steps: int, - input_image: torch.Tensor | None = None, + input_image: Input.Image | None = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode): class FluxKontextMaxImageNode(FluxKontextProImageNode): - """ - Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio. - """ - DESCRIPTION = cleandoc(__doc__ or "") + DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio." BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" NODE_ID = "FluxKontextMaxImageNode" DISPLAY_NAME = "Flux.1 Kontext [max] Image" class FluxProExpandNode(IO.ComfyNode): - """ - Outpaints image based on prompt. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode): node_id="FluxProExpandNode", display_name="Flux.1 Expand Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Outpaints image based on prompt.", inputs=[ IO.Image.Input("image"), IO.String.Input( @@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, prompt: str, prompt_upsampling: bool, top: int, @@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode): class FluxProFillNode(IO.ComfyNode): - """ - Inpaints image based on mask and prompt. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode): node_id="FluxProFillNode", display_name="Flux.1 Fill Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Inpaints image based on mask and prompt.", inputs=[ IO.Image.Input("image"), IO.Mask.Input("mask"), @@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, - mask: torch.Tensor, + image: Input.Image, + mask: Input.Image, prompt: str, prompt_upsampling: bool, steps: int, @@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode): class Flux2ProImageNode(IO.ComfyNode): + NODE_ID = "Flux2ProImageNode" + DISPLAY_NAME = "Flux.2 [pro] Image" + API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate" + @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( - node_id="Flux2ProImageNode", - display_name="Flux.2 [pro] Image", + node_id=cls.NODE_ID, + display_name=cls.DISPLAY_NAME, category="api node/image/BFL", description="Generates images synchronously based on prompt and resolution.", inputs=[ @@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode): ), IO.Boolean.Input( "prompt_upsampling", - default=False, + default=True, tooltip="Whether to perform upsampling on the prompt. " - "If active, automatically modifies the prompt for more creative generation, " - "but results are nondeterministic (same seed will not produce exactly the same result).", + "If active, automatically modifies the prompt for more creative generation.", ), - IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), + IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."), ], outputs=[IO.Image.Output()], hidden=[ @@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode): height: int, seed: int, prompt_upsampling: bool, - images: torch.Tensor | None = None, + images: Input.Image | None = None, ) -> IO.NodeOutput: reference_images = {} if images is not None: @@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode): reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) initial_response = await sync_op( cls, - ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), + ApiEndpoint(path=cls.API_ENDPOINT, method="POST"), response_model=BFLFluxProGenerateResponse, data=Flux2ProGenerateRequest( prompt=prompt, @@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class Flux2MaxImageNode(Flux2ProImageNode): + + NODE_ID = "Flux2MaxImageNode" + DISPLAY_NAME = "Flux.2 [max] Image" + API_ENDPOINT = "/proxy/bfl/flux-2-max/generate" + + class BFLExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension): FluxProExpandNode, FluxProFillNode, Flux2ProImageNode, + Flux2MaxImageNode, ] diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 57c0218d0..636cc1265 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -112,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), @@ -215,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), @@ -346,7 +346,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image.', optional=True, ), @@ -380,7 +380,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): sequential_image_generation: str = "disabled", max_images: int = 1, seed: int = 0, - watermark: bool = True, + watermark: bool = False, fail_on_partial: bool = True, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) @@ -507,7 +507,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -617,7 +617,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -739,7 +739,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -862,7 +862,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index ad0f4b4d1..e8ed7e797 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -34,6 +34,7 @@ from comfy_api_nodes.util import ( ApiEndpoint, audio_to_base64_string, bytesio_to_image_tensor, + download_url_to_image_tensor, get_number_of_images, sync_op, tensor_to_base64_string, @@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera ) parts = [] for part in response.candidates[0].content.parts: - if part_type == "text" and hasattr(part, "text") and part.text: + if part_type == "text" and part.text: parts.append(part) - elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: + elif part.inlineData and part.inlineData.mimeType == part_type: + parts.append(part) + elif part.fileData and part.fileData.mimeType == part_type: parts.append(part) # Skip parts that don't match the requested type return parts @@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: +async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/png") for part in parts: - image_data = base64.b64decode(part.inlineData.data) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + if part.inlineData: + image_data = base64.b64decode(part.inlineData.data) + returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + else: + returned_image = await download_url_to_image_tensor(part.fileData.fileUri) image_tensors.append(returned_image) if len(image_tensors) == 0: return torch.zeros((1, 1024, 1024, 4)) @@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode): response = await sync_op( cls, - endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiImage2(IO.ComfyNode): @@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode): response = await sync_op( cls, - ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 1a6364fa0..58259e029 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -51,6 +51,7 @@ from comfy_api_nodes.apis import ( ) from comfy_api_nodes.apis.kling_api import ( ImageToVideoWithAudioRequest, + MotionControlRequest, OmniImageParamImage, OmniParamImage, OmniParamVideo, @@ -858,7 +859,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): tooltip="A text prompt describing the video content. " "This can include both positive and negative descriptions.", ), - IO.Combo.Input("duration", options=["5", "10"]), + IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider), IO.Image.Input("first_frame"), IO.Image.Input( "end_frame", @@ -897,6 +898,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): validate_string(prompt, min_length=1, max_length=2500) if end_frame is not None and reference_images is not None: raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") + if duration not in (5, 10) and end_frame is None and reference_images is None: + raise ValueError( + "Duration is only supported for 5 or 10 seconds if there is no end frame or reference images." + ) validate_image_dimensions(first_frame, min_width=300, min_height=300) validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) image_list: list[OmniParamImage] = [ @@ -2159,6 +2164,91 @@ class ImageToVideoWithAudio(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) +class MotionControl(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingMotionControl", + display_name="Kling Motion Control", + category="api node/video/Kling", + inputs=[ + IO.String.Input("prompt", multiline=True), + IO.Image.Input("reference_image"), + IO.Video.Input( + "reference_video", + tooltip="Motion reference video used to drive movement/expression.\n" + "Duration limits depend on character_orientation:\n" + " - image: 3–10s (max 10s)\n" + " - video: 3–30s (max 30s)", + ), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Combo.Input( + "character_orientation", + options=["video", "image"], + tooltip="Controls where the character's facing/orientation comes from.\n" + "video: movements, expressions, camera moves, and orientation " + "follow the motion reference video (other details via prompt).\n" + "image: movements and expressions still follow the motion reference video, " + "but the character orientation matches the reference image (camera/other details via prompt).", + ), + IO.Combo.Input("mode", options=["pro", "std"]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + reference_image: Input.Image, + reference_video: Input.Video, + keep_original_sound: bool, + character_orientation: str, + mode: str, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2500) + validate_image_dimensions(reference_image, min_width=340, min_height=340) + validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1)) + if character_orientation == "image": + validate_video_duration(reference_video, min_duration=3, max_duration=10) + else: + validate_video_duration(reference_video, min_duration=3, max_duration=30) + validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"), + response_model=TaskStatusResponse, + data=MotionControlRequest( + prompt=prompt, + image_url=(await upload_images_to_comfyapi(cls, reference_image))[0], + video_url=await upload_video_to_comfyapi(cls, reference_video), + keep_original_sound="yes" if keep_original_sound else "no", + character_orientation=character_orientation, + mode=mode, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + class KlingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -2184,6 +2274,7 @@ class KlingExtension(ComfyExtension): OmniProImageNode, TextToVideoWithAudio, ImageToVideoWithAudio, + MotionControl, ] diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index c8da5464b..a6205a34f 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -1,46 +1,45 @@ -from io import BytesIO +import base64 import os from enum import Enum -from inspect import cleandoc +from io import BytesIO + import numpy as np import torch from PIL import Image -import folder_paths -import base64 -from comfy_api.latest import IO, ComfyExtension from typing_extensions import override - +import folder_paths +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis import ( - OpenAIImageGenerationRequest, - OpenAIImageEditRequest, - OpenAIImageGenerationResponse, - OpenAICreateResponse, - OpenAIResponse, CreateModelResponseProperties, - Item, - OutputContent, - InputImageContent, Detail, - InputTextContent, - InputMessage, - InputMessageContentList, InputContent, InputFileContent, + InputImageContent, + InputMessage, + InputMessageContentList, + InputTextContent, + Item, + OpenAICreateResponse, + OpenAIResponse, + OutputContent, +) +from comfy_api_nodes.apis.openai_api import ( + OpenAIImageEditRequest, + OpenAIImageGenerationRequest, + OpenAIImageGenerationResponse, ) - from comfy_api_nodes.util import ( - downscale_image_tensor, - download_url_to_bytesio, - validate_string, - tensor_to_base64_string, ApiEndpoint, - sync_op, + download_url_to_bytesio, + downscale_image_tensor, poll_op, + sync_op, + tensor_to_base64_string, text_filepath_to_data_uri, + validate_string, ) - RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" STARTING_POINT_ID_PATTERN = r"" @@ -98,9 +97,6 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten class OpenAIDalle2(IO.ComfyNode): - """ - Generates images synchronously via OpenAI's DALL·E 2 endpoint. - """ @classmethod def define_schema(cls): @@ -108,7 +104,7 @@ class OpenAIDalle2(IO.ComfyNode): node_id="OpenAIDalle2", display_name="OpenAI DALL·E 2", category="api node/image/OpenAI", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.", inputs=[ IO.String.Input( "prompt", @@ -234,9 +230,6 @@ class OpenAIDalle2(IO.ComfyNode): class OpenAIDalle3(IO.ComfyNode): - """ - Generates images synchronously via OpenAI's DALL·E 3 endpoint. - """ @classmethod def define_schema(cls): @@ -244,7 +237,7 @@ class OpenAIDalle3(IO.ComfyNode): node_id="OpenAIDalle3", display_name="OpenAI DALL·E 3", category="api node/image/OpenAI", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.", inputs=[ IO.String.Input( "prompt", @@ -326,10 +319,16 @@ class OpenAIDalle3(IO.ComfyNode): return IO.NodeOutput(await validate_and_cast_response(response)) +def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None: + # https://platform.openai.com/docs/pricing + return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0 + + +def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None: + return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0 + + class OpenAIGPTImage1(IO.ComfyNode): - """ - Generates images synchronously via OpenAI's GPT Image 1 endpoint. - """ @classmethod def define_schema(cls): @@ -337,13 +336,13 @@ class OpenAIGPTImage1(IO.ComfyNode): node_id="OpenAIGPTImage1", display_name="OpenAI GPT Image 1", category="api node/image/OpenAI", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.", inputs=[ IO.String.Input( "prompt", default="", multiline=True, - tooltip="Text prompt for GPT Image 1", + tooltip="Text prompt for GPT Image", ), IO.Int.Input( "seed", @@ -365,8 +364,8 @@ class OpenAIGPTImage1(IO.ComfyNode): ), IO.Combo.Input( "background", - default="opaque", - options=["opaque", "transparent"], + default="auto", + options=["auto", "opaque", "transparent"], tooltip="Return image with or without background", optional=True, ), @@ -397,6 +396,11 @@ class OpenAIGPTImage1(IO.ComfyNode): tooltip="Optional mask for inpainting (white areas will be replaced)", optional=True, ), + IO.Combo.Input( + "model", + options=["gpt-image-1", "gpt-image-1.5"], + optional=True, + ), ], outputs=[ IO.Image.Output(), @@ -412,32 +416,34 @@ class OpenAIGPTImage1(IO.ComfyNode): @classmethod async def execute( cls, - prompt, - seed=0, - quality="low", - background="opaque", - image=None, - mask=None, - n=1, - size="1024x1024", + prompt: str, + seed: int = 0, + quality: str = "low", + background: str = "opaque", + image: Input.Image | None = None, + mask: Input.Image | None = None, + n: int = 1, + size: str = "1024x1024", + model: str = "gpt-image-1", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - model = "gpt-image-1" - path = "/proxy/openai/images/generations" - content_type = "application/json" - request_class = OpenAIImageGenerationRequest - files = [] + + if mask is not None and image is None: + raise ValueError("Cannot use a mask without an input image") + + if model == "gpt-image-1": + price_extractor = calculate_tokens_price_image_1 + elif model == "gpt-image-1.5": + price_extractor = calculate_tokens_price_image_1_5 + else: + raise ValueError(f"Unknown model: {model}") if image is not None: - path = "/proxy/openai/images/edits" - request_class = OpenAIImageEditRequest - content_type = "multipart/form-data" - + files = [] batch_size = image.shape[0] - for i in range(batch_size): - single_image = image[i : i + 1] - scaled_image = downscale_image_tensor(single_image).squeeze() + single_image = image[i: i + 1] + scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze() image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) @@ -450,44 +456,59 @@ class OpenAIGPTImage1(IO.ComfyNode): else: files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png"))) - if mask is not None: - if image is None: - raise Exception("Cannot use a mask without an input image") - if image.shape[0] != 1: - raise Exception("Cannot use a mask with multiple image") - if mask.shape[1:] != image.shape[1:-1]: - raise Exception("Mask and Image must be the same size") - batch, height, width = mask.shape - rgba_mask = torch.zeros(height, width, 4, device="cpu") - rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() + if mask is not None: + if image.shape[0] != 1: + raise Exception("Cannot use a mask with multiple image") + if mask.shape[1:] != image.shape[1:-1]: + raise Exception("Mask and Image must be the same size") + _, height, width = mask.shape + rgba_mask = torch.zeros(height, width, 4, device="cpu") + rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() - scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze() + scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze() - mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) - mask_img = Image.fromarray(mask_np) - mask_img_byte_arr = BytesIO() - mask_img.save(mask_img_byte_arr, format="PNG") - mask_img_byte_arr.seek(0) - files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) - - # Build the operation - response = await sync_op( - cls, - ApiEndpoint(path=path, method="POST"), - response_model=OpenAIImageGenerationResponse, - data=request_class( - model=model, - prompt=prompt, - quality=quality, - background=background, - n=n, - seed=seed, - size=size, - ), - files=files if files else None, - content_type=content_type, - ) + mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) + mask_img = Image.fromarray(mask_np) + mask_img_byte_arr = BytesIO() + mask_img.save(mask_img_byte_arr, format="PNG") + mask_img_byte_arr.seek(0) + files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/openai/images/edits", method="POST"), + response_model=OpenAIImageGenerationResponse, + data=OpenAIImageEditRequest( + model=model, + prompt=prompt, + quality=quality, + background=background, + n=n, + seed=seed, + size=size, + moderation="low", + ), + content_type="multipart/form-data", + files=files, + price_extractor=price_extractor, + ) + else: + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/openai/images/generations", method="POST"), + response_model=OpenAIImageGenerationResponse, + data=OpenAIImageGenerationRequest( + model=model, + prompt=prompt, + quality=quality, + background=background, + n=n, + seed=seed, + size=size, + moderation="low", + ), + price_extractor=price_extractor, + ) return IO.NodeOutput(await validate_and_cast_response(response)) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index f522756e5..b04575ad8 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = { "Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Creative": "slc-1", } -UPSCALER_VALUES_MAP = { - "FullHD (1080p)": 1920, - "4K (2160p)": 3840, -} class TopazImageEnhance(IO.ComfyNode): @@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode): IO.Video.Input("video"), IO.Boolean.Input("upscaler_enabled", default=True), IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), - IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())), + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), IO.Combo.Input( "upscaler_creativity", options=["low", "middle", "high"], @@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode): target_frame_rate = src_frame_rate filters = [] if upscaler_enabled: - target_width = UPSCALER_VALUES_MAP[upscaler_resolution] - target_height = UPSCALER_VALUES_MAP[upscaler_resolution] + if "1080p" in upscaler_resolution: + target_pixel_p = 1080 + max_long_side = 1920 + else: + target_pixel_p = 2160 + max_long_side = 3840 + ar = src_width / src_height + if src_width >= src_height: + # Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width + target_height = target_pixel_p + target_width = int(target_height * ar) + # Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs) + if target_width > max_long_side: + target_width = max_long_side + target_height = int(target_width / ar) + else: + # Portrait; Attempt to set width to target (e.g., 2160), calculate height + target_width = target_pixel_p + target_height = int(target_width / ar) + # Check if height exceeds standard bounds + if target_height > max_long_side: + target_height = max_long_side + target_width = int(target_height * ar) + if target_width % 2 != 0: + target_width += 1 + if target_height % 2 != 0: + target_height += 1 filters.append( topaz_api.VideoEnhancementFilter( model=UPSCALER_MODELS_MAP[upscaler_model], diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index e165b8380..13a6bfd91 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -168,6 +168,8 @@ class VeoVideoGenerationNode(IO.ComfyNode): # Only add generateAudio for Veo 3 models if model.find("veo-2.0") == -1: parameters["generateAudio"] = generate_audio + # force "enhance_prompt" to True for Veo3 models + parameters["enhancePrompt"] = True initial_response = await sync_op( cls, @@ -291,7 +293,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Boolean.Input( "enhance_prompt", default=True, - tooltip="Whether to enhance the prompt with AI assistance", + tooltip="This parameter is deprecated and ignored.", optional=True, ), IO.Combo.Input( diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 17b680e13..1675fd863 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -46,14 +46,14 @@ class Txt2ImageParametersField(BaseModel): n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) class Image2ImageParametersField(BaseModel): size: str | None = Field(None) n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) - watermark: bool = Field(True) + watermark: bool = Field(False) class Text2VideoParametersField(BaseModel): @@ -61,7 +61,7 @@ class Text2VideoParametersField(BaseModel): seed: int = Field(..., ge=0, le=2147483647) duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) audio: bool = Field(False, description="Whether to generate audio automatically.") shot_type: str = Field("single") @@ -71,7 +71,7 @@ class Image2VideoParametersField(BaseModel): seed: int = Field(..., ge=0, le=2147483647) duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) audio: bool = Field(False, description="Whether to generate audio automatically.") shot_type: str = Field("single") @@ -208,7 +208,7 @@ class WanTextToImageApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -234,7 +234,7 @@ class WanTextToImageApi(IO.ComfyNode): height: int = 1024, seed: int = 0, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, ): initial_response = await sync_op( cls, @@ -327,7 +327,7 @@ class WanImageToImageApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -353,7 +353,7 @@ class WanImageToImageApi(IO.ComfyNode): # width: int = 1024, # height: int = 1024, seed: int = 0, - watermark: bool = True, + watermark: bool = False, ): n_images = get_number_of_images(image) if n_images not in (1, 2): @@ -476,7 +476,7 @@ class WanTextToVideoApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -512,7 +512,7 @@ class WanTextToVideoApi(IO.ComfyNode): seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, shot_type: str = "single", ): if "480p" in size and model == "wan2.6-t2v": @@ -637,7 +637,7 @@ class WanImageToVideoApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -674,7 +674,7 @@ class WanImageToVideoApi(IO.ComfyNode): seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, shot_type: str = "single", ): if get_number_of_images(image) != 1: diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index bf37cba5f..f372ec7b5 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -430,9 +430,9 @@ def _display_text( if status: display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") if price is not None: - p = f"{float(price):,.4f}".rstrip("0").rstrip(".") + p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".") if p != "0": - display_lines.append(f"Price: ${p}") + display_lines.append(f"Price: {p} credits") if text is not None: display_lines.append(text) if display_lines: diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index c57457580..d64239c86 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -129,7 +129,7 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: return img_byte_arr -def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: +def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor: """Downscale input image tensor to roughly the specified total pixels.""" samples = image.movedim(-1, 1) total = int(total_pixels) diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py new file mode 100644 index 000000000..59fb49357 --- /dev/null +++ b/comfy_execution/jobs.py @@ -0,0 +1,291 @@ +""" +Job utilities for the /api/jobs endpoint. +Provides normalization and helper functions for job status tracking. +""" + +from typing import Optional + +from comfy_api.internal import prune_dict + + +class JobStatus: + """Job status constants.""" + PENDING = 'pending' + IN_PROGRESS = 'in_progress' + COMPLETED = 'completed' + FAILED = 'failed' + + ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED] + + +# Media types that can be previewed in the frontend +PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'}) + +# 3D file extensions for preview fallback (no dedicated media_type exists) +THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'}) + + +def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: + """Extract create_time and workflow_id from extra_data. + + Returns: + tuple: (create_time, workflow_id) + """ + create_time = extra_data.get('create_time') + extra_pnginfo = extra_data.get('extra_pnginfo', {}) + workflow_id = extra_pnginfo.get('workflow', {}).get('id') + return create_time, workflow_id + + +def is_previewable(media_type: str, item: dict) -> bool: + """ + Check if an output item is previewable. + Matches frontend logic in ComfyUI_frontend/src/stores/queueStore.ts + Maintains backwards compatibility with existing logic. + + Priority: + 1. media_type is 'images', 'video', or 'audio' + 2. format field starts with 'video/' or 'audio/' + 3. filename has a 3D extension (.obj, .fbx, .gltf, .glb) + """ + if media_type in PREVIEWABLE_MEDIA_TYPES: + return True + + # Check format field (MIME type). + # Maintains backwards compatibility with how custom node outputs are handled in the frontend. + fmt = item.get('format', '') + if fmt and (fmt.startswith('video/') or fmt.startswith('audio/')): + return True + + # Check for 3D files by extension + filename = item.get('filename', '').lower() + if any(filename.endswith(ext) for ext in THREE_D_EXTENSIONS): + return True + + return False + + +def normalize_queue_item(item: tuple, status: str) -> dict: + """Convert queue item tuple to unified job dict. + + Expects item with sensitive data already removed (5 elements). + """ + priority, prompt_id, _, extra_data, _ = item + create_time, workflow_id = _extract_job_metadata(extra_data) + + return prune_dict({ + 'id': prompt_id, + 'status': status, + 'priority': priority, + 'create_time': create_time, + 'outputs_count': 0, + 'workflow_id': workflow_id, + }) + + +def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: bool = False) -> dict: + """Convert history item dict to unified job dict. + + History items have sensitive data already removed (prompt tuple has 5 elements). + """ + prompt_tuple = history_item['prompt'] + priority, _, prompt, extra_data, _ = prompt_tuple + create_time, workflow_id = _extract_job_metadata(extra_data) + + status_info = history_item.get('status', {}) + status_str = status_info.get('status_str') if status_info else None + if status_str == 'success': + status = JobStatus.COMPLETED + elif status_str == 'error': + status = JobStatus.FAILED + else: + status = JobStatus.COMPLETED + + outputs = history_item.get('outputs', {}) + outputs_count, preview_output = get_outputs_summary(outputs) + + execution_error = None + execution_start_time = None + execution_end_time = None + if status_info: + messages = status_info.get('messages', []) + for entry in messages: + if isinstance(entry, (list, tuple)) and len(entry) >= 2: + event_name, event_data = entry[0], entry[1] + if isinstance(event_data, dict): + if event_name == 'execution_start': + execution_start_time = event_data.get('timestamp') + elif event_name in ('execution_success', 'execution_error', 'execution_interrupted'): + execution_end_time = event_data.get('timestamp') + if event_name == 'execution_error': + execution_error = event_data + + job = prune_dict({ + 'id': prompt_id, + 'status': status, + 'priority': priority, + 'create_time': create_time, + 'execution_start_time': execution_start_time, + 'execution_end_time': execution_end_time, + 'execution_error': execution_error, + 'outputs_count': outputs_count, + 'preview_output': preview_output, + 'workflow_id': workflow_id, + }) + + if include_outputs: + job['outputs'] = outputs + job['execution_status'] = status_info + job['workflow'] = { + 'prompt': prompt, + 'extra_data': extra_data, + } + + return job + + +def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]: + """ + Count outputs and find preview in a single pass. + Returns (outputs_count, preview_output). + + Preview priority (matching frontend): + 1. type="output" with previewable media + 2. Any previewable media + """ + count = 0 + preview_output = None + fallback_preview = None + + for node_id, node_outputs in outputs.items(): + if not isinstance(node_outputs, dict): + continue + for media_type, items in node_outputs.items(): + # 'animated' is a boolean flag, not actual output items + if media_type == 'animated' or not isinstance(items, list): + continue + + for item in items: + if not isinstance(item, dict): + continue + count += 1 + + if preview_output is None and is_previewable(media_type, item): + enriched = { + **item, + 'nodeId': node_id, + 'mediaType': media_type + } + if item.get('type') == 'output': + preview_output = enriched + elif fallback_preview is None: + fallback_preview = enriched + + return count, preview_output or fallback_preview + + +def apply_sorting(jobs: list[dict], sort_by: str, sort_order: str) -> list[dict]: + """Sort jobs list by specified field and order.""" + reverse = (sort_order == 'desc') + + if sort_by == 'execution_duration': + def get_sort_key(job): + start = job.get('execution_start_time', 0) + end = job.get('execution_end_time', 0) + return end - start if end and start else 0 + else: + def get_sort_key(job): + return job.get('create_time', 0) + + return sorted(jobs, key=get_sort_key, reverse=reverse) + + +def get_job(prompt_id: str, running: list, queued: list, history: dict) -> Optional[dict]: + """ + Get a single job by prompt_id from history or queue. + + Args: + prompt_id: The prompt ID to look up + running: List of currently running queue items + queued: List of pending queue items + history: Dict of history items keyed by prompt_id + + Returns: + Job dict with full details, or None if not found + """ + if prompt_id in history: + return normalize_history_item(prompt_id, history[prompt_id], include_outputs=True) + + for item in running: + if item[1] == prompt_id: + return normalize_queue_item(item, JobStatus.IN_PROGRESS) + + for item in queued: + if item[1] == prompt_id: + return normalize_queue_item(item, JobStatus.PENDING) + + return None + + +def get_all_jobs( + running: list, + queued: list, + history: dict, + status_filter: Optional[list[str]] = None, + workflow_id: Optional[str] = None, + sort_by: str = "created_at", + sort_order: str = "desc", + limit: Optional[int] = None, + offset: int = 0 +) -> tuple[list[dict], int]: + """ + Get all jobs (running, pending, completed) with filtering and sorting. + + Args: + running: List of currently running queue items + queued: List of pending queue items + history: Dict of history items keyed by prompt_id + status_filter: List of statuses to include (from JobStatus.ALL) + workflow_id: Filter by workflow ID + sort_by: Field to sort by ('created_at', 'execution_duration') + sort_order: 'asc' or 'desc' + limit: Maximum number of items to return + offset: Number of items to skip + + Returns: + tuple: (jobs_list, total_count) + """ + jobs = [] + + if status_filter is None: + status_filter = JobStatus.ALL + + if JobStatus.IN_PROGRESS in status_filter: + for item in running: + jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS)) + + if JobStatus.PENDING in status_filter: + for item in queued: + jobs.append(normalize_queue_item(item, JobStatus.PENDING)) + + include_completed = JobStatus.COMPLETED in status_filter + include_failed = JobStatus.FAILED in status_filter + if include_completed or include_failed: + for prompt_id, history_item in history.items(): + is_failed = history_item.get('status', {}).get('status_str') == 'error' + if (is_failed and include_failed) or (not is_failed and include_completed): + jobs.append(normalize_history_item(prompt_id, history_item)) + + if workflow_id: + jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] + + jobs = apply_sorting(jobs, sort_by, sort_order) + + total_count = len(jobs) + + if offset > 0: + jobs = jobs[offset:] + if limit is not None: + jobs = jobs[:limit] + + return (jobs, total_count) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 71ea4e9ec..f19adf4b9 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -9,6 +9,7 @@ import comfy.utils import node_helpers from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import re class BasicScheduler(io.ComfyNode): @@ -671,7 +672,16 @@ class SamplerSEEDS2(io.ComfyNode): io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"), io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"), ], - outputs=[io.Sampler.Output()] + outputs=[io.Sampler.Output()], + description=( + "This sampler node can represent multiple samplers:\n\n" + "seeds_2\n" + "- default setting\n\n" + "exp_heun_2_x0\n" + "- solver_type=phi_2, r=1.0, eta=0.0\n\n" + "exp_heun_2_x0_sde\n" + "- solver_type=phi_2, r=1.0, eta=1.0, s_noise=1.0" + ) ) @classmethod @@ -751,8 +761,12 @@ class SamplerCustom(io.ComfyNode): out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) @@ -939,8 +953,12 @@ class SamplerCustomAdvanced(io.ComfyNode): out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) @@ -996,6 +1014,25 @@ class AddNoise(io.ComfyNode): add_noise = execute +class ManualSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ManualSigmas", + category="_for_testing/custom_sampling", + is_experimental=True, + inputs=[ + io.String.Input("sigmas", default="1, 0.5", multiline=False) + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: + sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas) + sigmas = [float(i) for i in sigmas] + sigmas = torch.FloatTensor(sigmas) + return io.NodeOutput(sigmas) class CustomSamplersExtension(ComfyExtension): @override @@ -1035,6 +1072,7 @@ class CustomSamplersExtension(ComfyExtension): DisableNoise, AddNoise, SamplerCustomAdvanced, + ManualSigmas, ] diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 4789d7d53..5ef851bd0 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -667,16 +667,19 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode): @classmethod def _process(cls, image, longer_edge): - img = tensor_to_pil(image) - w, h = img.size - if w > h: - new_w = longer_edge - new_h = int(h * (longer_edge / w)) - else: - new_h = longer_edge - new_w = int(w * (longer_edge / h)) - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - return pil_to_tensor(img) + resized_images = [] + for image_i in image: + img = tensor_to_pil(image_i) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + resized_images.append(pil_to_tensor(img)) + return torch.cat(resized_images, dim=0) class CenterCropImagesNode(ImageProcessingNode): @@ -1125,6 +1128,99 @@ class MergeTextListsNode(TextProcessingNode): # ========== Training Dataset Nodes ========== +class ResolutionBucket(io.ComfyNode): + """Bucket latents and conditions by resolution for efficient batch training.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ResolutionBucket", + display_name="Resolution Bucket", + category="dataset", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts to bucket by resolution.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists (must match latents length).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of batched latent dicts, one per resolution bucket.", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of condition lists, one per resolution bucket.", + ), + ], + ) + + @classmethod + def execute(cls, latents, conditioning): + # latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1 + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})." + ) + + # Flatten latents and conditions to individual samples + flat_latents = [] # list of (C, H, W) tensors + flat_conditions = [] # list of condition lists + + for latent_dict, cond in zip(latents, conditioning): + samples = latent_dict["samples"] # (B, C, H, W) + batch_size = samples.shape[0] + + # cond is a list of conditions with length == batch_size + for i in range(batch_size): + flat_latents.append(samples[i]) # (C, H, W) + flat_conditions.append(cond[i]) # single condition + + # Group by resolution (H, W) + buckets = {} # (H, W) -> {"latents": list, "conditions": list} + + for latent, cond in zip(flat_latents, flat_conditions): + # latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W) + h, w = latent.shape[-2], latent.shape[-1] + key = (h, w) + + if key not in buckets: + buckets[key] = {"latents": [], "conditions": []} + + buckets[key]["latents"].append(latent) + buckets[key]["conditions"].append(cond) + + # Convert buckets to output format + output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W) + output_conditions = [] # list[list[cond]] where each inner list has Bi conditions + + for (h, w), bucket_data in buckets.items(): + # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W) + stacked_latents = torch.stack(bucket_data["latents"], dim=0) + output_latents.append({"samples": stacked_latents}) + + # Conditions stay as list of condition lists + output_conditions.append(bucket_data["conditions"]) + + logging.info( + f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples" + ) + + logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples") + return io.NodeOutput(output_latents, output_conditions) + + class MakeTrainingDataset(io.ComfyNode): """Encode images with VAE and texts with CLIP to create a training dataset.""" @@ -1373,7 +1469,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = torch.load(f, weights_only=True) + shard_data = torch.load(f) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) @@ -1425,6 +1521,7 @@ class DatasetExtension(ComfyExtension): MakeTrainingDataset, SaveTrainingDataset, LoadTrainingDataset, + ResolutionBucket, ] diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 392aea32c..ce21caade 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -2,280 +2,231 @@ from __future__ import annotations import nodes import folder_paths -from comfy.cli_args import args -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import numpy as np import json import os import re -from io import BytesIO -from inspect import cleandoc import torch import comfy.utils -from comfy.comfy_types import FileLocator, IO from server import PromptServer +from comfy_api.latest import ComfyExtension, IO, UI +from typing_extensions import override + +SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later. MAX_RESOLUTION = nodes.MAX_RESOLUTION -class ImageCrop: +class ImageCrop(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "crop" + def define_schema(cls): + return IO.Schema( + node_id="ImageCrop", + display_name="Image Crop", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def crop(self, image, width, height, x, y): + @classmethod + def execute(cls, image, width, height, x, y) -> IO.NodeOutput: x = min(x, image.shape[2] - 1) y = min(y, image.shape[1] - 1) to_x = width + x to_y = height + y img = image[:,y:to_y, x:to_x, :] - return (img,) + return IO.NodeOutput(img) -class RepeatImageBatch: + crop = execute # TODO: remove + + +class RepeatImageBatch(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + def define_schema(cls): + return IO.Schema( + node_id="RepeatImageBatch", + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("amount", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/batch" - - def repeat(self, image, amount): + @classmethod + def execute(cls, image, amount) -> IO.NodeOutput: s = image.repeat((amount, 1,1,1)) - return (s,) + return IO.NodeOutput(s) -class ImageFromBatch: + repeat = execute # TODO: remove + + +class ImageFromBatch(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), - "length": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "frombatch" + def define_schema(cls): + return IO.Schema( + node_id="ImageFromBatch", + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("batch_index", default=0, min=0, max=4095), + IO.Int.Input("length", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/batch" - - def frombatch(self, image, batch_index, length): + @classmethod + def execute(cls, image, batch_index, length) -> IO.NodeOutput: s_in = image batch_index = min(s_in.shape[0] - 1, batch_index) length = min(s_in.shape[0] - batch_index, length) s = s_in[batch_index:batch_index + length].clone() - return (s,) + return IO.NodeOutput(s) + + frombatch = execute # TODO: remove -class ImageAddNoise: +class ImageAddNoise(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), - "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + def define_schema(cls): + return IO.Schema( + node_id="ImageAddNoise", + category="image", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image" - - def repeat(self, image, seed, strength): + @classmethod + def execute(cls, image, seed, strength) -> IO.NodeOutput: generator = torch.manual_seed(seed) s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) - return (s,) + return IO.NodeOutput(s) -class SaveAnimatedWEBP: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" + repeat = execute # TODO: remove - methods = {"default": 4, "fastest": 0, "slowest": 6} - @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "lossless": ("BOOLEAN", {"default": True}), - "quality": ("INT", {"default": 80, "min": 0, "max": 100}), - "method": (list(s.methods.keys()),), - # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): - method = self.methods.get(method) - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results: list[FileLocator] = [] - pil_images = [] - for image in images: - i = 255. * image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = pil_images[0].getexif() - if not args.disable_metadata: - if prompt is not None: - metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) - if extra_pnginfo is not None: - inital_exif = 0x010f - for x in extra_pnginfo: - metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x])) - inital_exif -= 1 - - if num_frames == 0: - num_frames = len(pil_images) - - c = len(pil_images) - for i in range(0, c, num_frames): - file = f"{filename}_{counter:05}_.webp" - pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - - animated = num_frames != 1 - return { "ui": { "images": results, "animated": (animated,) } } - -class SaveAnimatedPNG: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAnimatedWEBP(IO.ComfyNode): + COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6} @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedWEBP", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Boolean.Input("lossless", default=True), + IO.Int.Input("quality", default=80, min=0, max=100), + IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - RETURN_TYPES = () - FUNCTION = "save_images" + @classmethod + def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_webp_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + lossless=lossless, + quality=quality, + method=cls.COMPRESS_METHODS.get(method) + ) + ) - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results = list() - pil_images = [] - for image in images: - i = 255. * image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) - - file = f"{filename}_{counter:05}_.png" - pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - - return { "ui": { "images": results, "animated": (True,)} } - -class SVG: - """ - Stores SVG representations via a list of BytesIO objects. - """ - def __init__(self, data: list[BytesIO]): - self.data = data - - def combine(self, other: 'SVG') -> 'SVG': - return SVG(self.data + other.data) - - @staticmethod - def combine_all(svgs: list['SVG']) -> 'SVG': - all_svgs_list: list[BytesIO] = [] - for svg_item in svgs: - all_svgs_list.extend(svg_item.data) - return SVG(all_svgs_list) + save_images = execute # TODO: remove -class ImageStitch: +class SaveAnimatedPNG(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedPNG", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Int.Input("compress_level", default=4, min=0, max=9), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_png_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + compress_level=compress_level, + ) + ) + + save_images = execute # TODO: remove + + +class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image1": ("IMAGE",), - "direction": (["right", "down", "left", "up"], {"default": "right"}), - "match_image_size": ("BOOLEAN", {"default": True}), - "spacing_width": ( - "INT", - {"default": 0, "min": 0, "max": 1024, "step": 2}, - ), - "spacing_color": ( - ["white", "black", "red", "green", "blue"], - {"default": "white"}, - ), - }, - "optional": { - "image2": ("IMAGE",), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="ImageStitch", + display_name="Image Stitch", + description="Stitches image2 to image1 in the specified direction.\n" + "If image2 is not provided, returns image1 unchanged.\n" + "Optional spacing can be added between images.", + category="image/transform", + inputs=[ + IO.Image.Input("image1"), + IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"), + IO.Boolean.Input("match_image_size", default=True), + IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2), + IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"), + IO.Image.Input("image2", optional=True), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "stitch" - CATEGORY = "image/transform" - DESCRIPTION = """ -Stitches image2 to image1 in the specified direction. -If image2 is not provided, returns image1 unchanged. -Optional spacing can be added between images. -""" - - def stitch( - self, + @classmethod + def execute( + cls, image1, direction, match_image_size, spacing_width, spacing_color, image2=None, - ): + ) -> IO.NodeOutput: if image2 is None: - return (image1,) + return IO.NodeOutput(image1) # Handle batch size differences if image1.shape[0] != image2.shape[0]: @@ -412,36 +363,30 @@ Optional spacing can be added between images. images.insert(1, spacing) concat_dim = 2 if direction in ["left", "right"] else 1 - return (torch.cat(images, dim=concat_dim),) + return IO.NodeOutput(torch.cat(images, dim=concat_dim)) + + stitch = execute # TODO: remove + + +class ResizeAndPadImage(IO.ComfyNode): -class ResizeAndPadImage: @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ("IMAGE",), - "target_width": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "target_height": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "padding_color": (["white", "black"],), - "interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ResizeAndPadImage", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("padding_color", options=["white", "black"]), + IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "resize_and_pad" - CATEGORY = "image/transform" - - def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation): + @classmethod + def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput: batch_size, orig_height, orig_width, channels = image.shape scale_w = target_width / orig_width @@ -469,52 +414,47 @@ class ResizeAndPadImage: padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized output = padded.permute(0, 2, 3, 1) - return (output,) + return IO.NodeOutput(output) -class SaveSVGNode: - """ - Save SVG files on disk. - """ + resize_and_pad = execute # TODO: remove - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" - RETURN_TYPES = () - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "save_svg" - CATEGORY = "image/save" # Changed - OUTPUT_NODE = True +class SaveSVGNode(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "svg": ("SVG",), # Changed - "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) - }, - "hidden": { - "prompt": "PROMPT", - "extra_pnginfo": "EXTRA_PNGINFO" - } - } + def define_schema(cls): + return IO.Schema( + node_id="SaveSVGNode", + description="Save SVG files on disk.", + category="image/save", + inputs=[ + IO.SVG.Input("svg"), + IO.String.Input( + "filename_prefix", + default="svg/ComfyUI", + tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results = list() + @classmethod + def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) + results: list[UI.SavedResult] = [] # Prepare metadata JSON metadata_dict = {} - if prompt is not None: - metadata_dict["prompt"] = prompt - if extra_pnginfo is not None: - metadata_dict.update(extra_pnginfo) + if cls.hidden.prompt is not None: + metadata_dict["prompt"] = cls.hidden.prompt + if cls.hidden.extra_pnginfo is not None: + metadata_dict.update(cls.hidden.extra_pnginfo) # Convert metadata to JSON string metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None + for batch_number, svg_bytes in enumerate(svg.data): filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) file = f"{filename_with_batch_num}_{counter:05}_.svg" @@ -544,57 +484,64 @@ class SaveSVGNode: with open(os.path.join(full_output_folder, file), 'wb') as svg_file: svg_file.write(svg_content.encode('utf-8')) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) + results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) counter += 1 - return { "ui": { "images": results } } + return IO.NodeOutput(ui={"images": results}) -class GetImageSize: + save_svg = execute # TODO: remove + + +class GetImageSize(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - }, - "hidden": { - "unique_id": "UNIQUE_ID", - } - } + def define_schema(cls): + return IO.Schema( + node_id="GetImageSize", + display_name="Get Image Size", + description="Returns width and height of the image, and passes it through unchanged.", + category="image", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + IO.Int.Output(display_name="batch_size"), + ], + hidden=[IO.Hidden.unique_id], + ) - RETURN_TYPES = (IO.INT, IO.INT, IO.INT) - RETURN_NAMES = ("width", "height", "batch_size") - FUNCTION = "get_size" - - CATEGORY = "image" - DESCRIPTION = """Returns width and height of the image, and passes it through unchanged.""" - - def get_size(self, image, unique_id=None) -> tuple[int, int]: + @classmethod + def execute(cls, image) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] batch_size = image.shape[0] # Send progress text to display size on the node - if unique_id: - PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id) + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id) - return width, height, batch_size + return IO.NodeOutput(width, height, batch_size) + + get_size = execute # TODO: remove + + +class ImageRotate(IO.ComfyNode): -class ImageRotate: @classmethod - def INPUT_TYPES(s): - return {"required": { "image": (IO.IMAGE,), - "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],), - }} - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "rotate" + def define_schema(cls): + return IO.Schema( + node_id="ImageRotate", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def rotate(self, image, rotation): + @classmethod + def execute(cls, image, rotation) -> IO.NodeOutput: rotate_by = 0 if rotation.startswith("90"): rotate_by = 1 @@ -604,41 +551,57 @@ class ImageRotate: rotate_by = 3 image = torch.rot90(image, k=rotate_by, dims=[2, 1]) - return (image,) + return IO.NodeOutput(image) + + rotate = execute # TODO: remove + + +class ImageFlip(IO.ComfyNode): -class ImageFlip: @classmethod - def INPUT_TYPES(s): - return {"required": { "image": (IO.IMAGE,), - "flip_method": (["x-axis: vertically", "y-axis: horizontally"],), - }} - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "flip" + def define_schema(cls): + return IO.Schema( + node_id="ImageFlip", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def flip(self, image, flip_method): + @classmethod + def execute(cls, image, flip_method) -> IO.NodeOutput: if flip_method.startswith("x"): image = torch.flip(image, dims=[1]) elif flip_method.startswith("y"): image = torch.flip(image, dims=[2]) - return (image,) + return IO.NodeOutput(image) -class ImageScaleToMaxDimension: - upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"] + flip = execute # TODO: remove + + +class ImageScaleToMaxDimension(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "upscale_method": (s.upscale_methods,), - "largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "upscale" + def define_schema(cls): + return IO.Schema( + node_id="ImageScaleToMaxDimension", + category="image/upscaling", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input( + "upscale_method", + options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"], + ), + IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/upscaling" - - def upscale(self, image, upscale_method, largest_size): + @classmethod + def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] @@ -655,20 +618,30 @@ class ImageScaleToMaxDimension: samples = image.movedim(-1, 1) s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = s.movedim(1, -1) - return (s,) + return IO.NodeOutput(s) -NODE_CLASS_MAPPINGS = { - "ImageCrop": ImageCrop, - "RepeatImageBatch": RepeatImageBatch, - "ImageFromBatch": ImageFromBatch, - "ImageAddNoise": ImageAddNoise, - "SaveAnimatedWEBP": SaveAnimatedWEBP, - "SaveAnimatedPNG": SaveAnimatedPNG, - "SaveSVGNode": SaveSVGNode, - "ImageStitch": ImageStitch, - "ResizeAndPadImage": ResizeAndPadImage, - "GetImageSize": GetImageSize, - "ImageRotate": ImageRotate, - "ImageFlip": ImageFlip, - "ImageScaleToMaxDimension": ImageScaleToMaxDimension, -} + upscale = execute # TODO: remove + + +class ImagesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCrop, + RepeatImageBatch, + ImageFromBatch, + ImageAddNoise, + SaveAnimatedWEBP, + SaveAnimatedPNG, + SaveSVGNode, + ImageStitch, + ResizeAndPadImage, + GetImageSize, + ImageRotate, + ImageFlip, + ImageScaleToMaxDimension, + ] + + +async def comfy_entrypoint() -> ImagesExtension: + return ImagesExtension() diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index e439b18ef..2815c5ffc 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -5,6 +5,7 @@ import nodes from typing_extensions import override from comfy_api.latest import ComfyExtension, io import logging +import math def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: @@ -207,6 +208,47 @@ class LatentCut(io.ComfyNode): samples_out["samples"] = torch.narrow(s1, dim, index, amount) return io.NodeOutput(samples_out) +class LatentCutToBatch(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentCutToBatch", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Combo.Input("dim", options=["t", "x", "y"]), + io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, samples, dim, slice_size) -> io.NodeOutput: + samples_out = samples.copy() + + s1 = samples["samples"] + + if "x" in dim: + dim = s1.ndim - 1 + elif "y" in dim: + dim = s1.ndim - 2 + elif "t" in dim: + dim = s1.ndim - 3 + + if dim < 2: + return io.NodeOutput(samples) + + s = s1.movedim(dim, 1) + if s.shape[1] < slice_size: + slice_size = s.shape[1] + elif s.shape[1] % slice_size != 0: + s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size] + new_shape = [-1, slice_size] + list(s.shape[2:]) + samples_out["samples"] = s.reshape(new_shape).movedim(1, dim) + return io.NodeOutput(samples_out) + class LatentBatch(io.ComfyNode): @classmethod def define_schema(cls): @@ -435,6 +477,7 @@ class LatentExtension(ComfyExtension): LatentInterpolate, LatentConcat, LatentCut, + LatentCutToBatch, LatentBatch, LatentBatchSeedBehavior, LatentApplyOperation, diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 2a0cfcf18..1355b3c93 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -348,7 +348,7 @@ class ZImageControlPatch: if self.mask is None: mask_ = torch.zeros_like(inpaint_image_latent)[:, :1] else: - mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center") + mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center") if latent_image is None: latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5)) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 34c388a5a..ca2cdeb50 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -221,6 +221,7 @@ class ImageScaleToTotalPixels(io.ComfyNode): io.Image.Input("image"), io.Combo.Input("upscale_method", options=cls.upscale_methods), io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + io.Int.Input("resolution_steps", default=1, min=1, max=256), ], outputs=[ io.Image.Output(), @@ -228,15 +229,15 @@ class ImageScaleToTotalPixels(io.ComfyNode): ) @classmethod - def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput: + def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput: samples = image.movedim(-1,1) - total = int(megapixels * 1024 * 1024) + total = megapixels * 1024 * 1024 scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) + width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps + height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps - s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled") s = s.movedim(1,-1) return io.NodeOutput(s) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index 525239ae5..fde8fac9a 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -3,7 +3,9 @@ import comfy.utils import math from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import comfy.model_management +import torch +import nodes class TextEncodeQwenImageEdit(io.ComfyNode): @classmethod @@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode): return io.NodeOutput(conditioning) +class EmptyQwenImageLayeredLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyQwenImageLayeredLatentImage", + display_name="Empty Qwen Image Layered Latent", + category="latent/qwen", + inputs=[ + io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + + class QwenExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeQwenImageEdit, TextEncodeQwenImageEditPlus, + EmptyQwenImageLayeredLatentImage, ] diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 19b8baaf4..364804205 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont from typing_extensions import override import comfy.samplers +import comfy.sampler_helpers import comfy.sd import comfy.utils import comfy.model_management @@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui from comfy.utils import ProgressBar +class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic): + """ + CFGGuider with modifications for training specific logic + """ + def outer_sample( + self, + noise, + latent_image, + sampler, + sigmas, + denoise_mask=None, + callback=None, + disable_pbar=False, + seed=None, + latent_shapes=None, + ): + self.inner_model, self.conds, self.loaded_models = ( + comfy.sampler_helpers.prepare_sampling( + self.model_patcher, + noise.shape, + self.conds, + self.model_options, + force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded + ) + ) + device = self.model_patcher.load_device + + if denoise_mask is not None: + denoise_mask = comfy.sampler_helpers.prepare_mask( + denoise_mask, noise.shape, device + ) + + noise = noise.to(device) + latent_image = latent_image.to(device) + sigmas = sigmas.to(device) + comfy.samplers.cast_to_load_options( + self.model_options, device=device, dtype=self.model_patcher.model_dtype() + ) + + try: + self.model_patcher.pre_run() + output = self.inner_sample( + noise, + latent_image, + device, + sampler, + sigmas, + denoise_mask, + callback, + disable_pbar, + seed, + latent_shapes=latent_shapes, + ) + finally: + self.model_patcher.cleanup() + + comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) + del self.inner_model + del self.loaded_models + return output + + def make_batch_extra_option_dict(d, indicies, full_size=None): new_dict = {} for k, v in d.items(): @@ -65,6 +128,7 @@ class TrainSampler(comfy.samplers.Sampler): seed=0, training_dtype=torch.bfloat16, real_dataset=None, + bucket_latents=None, ): self.loss_fn = loss_fn self.optimizer = optimizer @@ -75,6 +139,28 @@ class TrainSampler(comfy.samplers.Sampler): self.seed = seed self.training_dtype = training_dtype self.real_dataset: list[torch.Tensor] | None = real_dataset + # Bucket mode data + self.bucket_latents: list[torch.Tensor] | None = ( + bucket_latents # list of (Bi, C, Hi, Wi) + ) + # Precompute bucket offsets and weights for sampling + if bucket_latents is not None: + self._init_bucket_data(bucket_latents) + else: + self.bucket_offsets = None + self.bucket_weights = None + self.num_images = None + + def _init_bucket_data(self, bucket_latents): + """Initialize bucket offsets and weights for sampling.""" + self.bucket_offsets = [0] + bucket_sizes = [] + for lat in bucket_latents: + bucket_sizes.append(lat.shape[0]) + self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0]) + self.num_images = self.bucket_offsets[-1] + # Weights for sampling buckets proportional to their size + self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32) def fwd_bwd( self, @@ -115,6 +201,108 @@ class TrainSampler(comfy.samplers.Sampler): bwd_loss.backward() return loss + def _generate_batch_sigmas(self, model_wrap, batch_size, device): + """Generate random sigma values for a batch.""" + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + for _ in range(batch_size) + ] + return torch.tensor(batch_sigmas).to(device) + + def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar): + """Execute one training step in bucket mode.""" + # Sample bucket (weighted by size), then sample batch from bucket + bucket_idx = torch.multinomial(self.bucket_weights, 1).item() + bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi) + bucket_size = bucket_latent.shape[0] + bucket_offset = self.bucket_offsets[bucket_idx] + + # Sample indices from this bucket (use all if bucket_size < batch_size) + actual_batch_size = min(self.batch_size, bucket_size) + relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist() + # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index) + absolute_indices = [bucket_offset + idx for idx in relative_indices] + + batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device) + + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, # Use flattened cond with absolute indices + absolute_indices, + extra_args, + self.num_images, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx}) + + def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar): + """Execute one training step in standard (non-bucket, non-multi-res) mode.""" + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device) + + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar): + """Execute one training step in multi-resolution mode (real_dataset is set).""" + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() + total_loss = 0 + for index in indicies: + single_latent = self.real_dataset[index].to(latent_image) + batch_noise = noisegen.generate_noise( + {"samples": single_latent} + ).to(single_latent.device) + batch_sigmas = ( + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + ) + batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + single_latent, + cond, + [index], + extra_args, + dataset_size, + bwd=False, + ) + total_loss += loss + total_loss = total_loss / self.grad_acc / len(indicies) + total_loss.backward() + if self.loss_callback: + self.loss_callback(total_loss.item()) + pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) + def sample( self, model_wrap, @@ -142,70 +330,18 @@ class TrainSampler(comfy.samplers.Sampler): noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise( self.seed + i * 1000 ) - indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - if self.real_dataset is None: - batch_latent = torch.stack([latent_image[i] for i in indicies]) - batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( - batch_latent.device - ) - batch_sigmas = [ - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) - for _ in range(min(self.batch_size, dataset_size)) - ] - batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - - loss = self.fwd_bwd( - model_wrap, - batch_sigmas, - batch_noise, - batch_latent, - cond, - indicies, - extra_args, - dataset_size, - bwd=True, - ) - if self.loss_callback: - self.loss_callback(loss.item()) - pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + if self.bucket_latents is not None: + self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar) + elif self.real_dataset is None: + self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) else: - total_loss = 0 - for index in indicies: - single_latent = self.real_dataset[index].to(latent_image) - batch_noise = noisegen.generate_noise( - {"samples": single_latent} - ).to(single_latent.device) - batch_sigmas = ( - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) - ) - batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) - loss = self.fwd_bwd( - model_wrap, - batch_sigmas, - batch_noise, - single_latent, - cond, - [index], - extra_args, - dataset_size, - bwd=False, - ) - total_loss += loss - total_loss = total_loss / self.grad_acc / len(indicies) - total_loss.backward() - if self.loss_callback: - self.loss_callback(total_loss.item()) - pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) + self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: self.optimizer.step() self.optimizer.zero_grad() - ui_pbar.update(1) + ui_pbar.update(1) torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -283,6 +419,364 @@ def unpatch(m): del m.org_forward +def _process_latents_bucket_mode(latents): + """Process latents for bucket mode training. + + Args: + latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi) + + Returns: + list of latent tensors + """ + bucket_latents = [] + for latent_dict in latents: + bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi) + return bucket_latents + + +def _process_latents_standard_mode(latents): + """Process latents for standard (non-bucket) mode training. + + Args: + latents: list of latent dicts or single latent dict + + Returns: + Processed latents (tensor or list of tensors) + """ + if len(latents) == 1: + return latents[0]["samples"] # Single latent dict + + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + return latent_list + + +def _process_conditioning(positive): + """Process conditioning - either single list or list of lists. + + Args: + positive: list of conditioning + + Returns: + Flattened conditioning list + """ + if len(positive) == 1: + return positive[0] # Single conditioning list + + # Multiple conditioning lists - flatten + flat_positive = [] + for cond in positive: + if isinstance(cond, list): + flat_positive.extend(cond) + else: + flat_positive.append(cond) + return flat_positive + + +def _prepare_latents_and_count(latents, dtype, bucket_mode): + """Convert latents to dtype and compute image counts. + + Args: + latents: Latents (tensor, list of tensors, or bucket list) + dtype: Target dtype + bucket_mode: Whether bucket mode is enabled + + Returns: + tuple: (processed_latents, num_images, multi_res) + """ + if bucket_mode: + # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi) + latents = [t.to(dtype) for t in latents] + num_buckets = len(latents) + num_images = sum(t.shape[0] for t in latents) + multi_res = False # Not using multi_res path in bucket mode + + logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") + for i, lat in enumerate(latents): + logging.info(f" Bucket {i}: shape {lat.shape}") + return latents, num_images, multi_res + + # Non-bucket mode + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + multi_res = False + else: + logging.error(f"Invalid latents type: {type(latents)}") + num_images = 0 + multi_res = False + + return latents, num_images, multi_res + + +def _validate_and_expand_conditioning(positive, num_images, bucket_mode): + """Validate conditioning count matches image count, expand if needed. + + Args: + positive: Conditioning list + num_images: Number of images + bucket_mode: Whether bucket mode is enabled + + Returns: + Validated/expanded conditioning list + + Raises: + ValueError: If conditioning count doesn't match image count + """ + if bucket_mode: + return positive # Skip validation in bucket mode + + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + if len(positive) == 1 and num_images > 1: + return positive * num_images + elif len(positive) != num_images: + raise ValueError( + f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." + ) + return positive + + +def _load_existing_lora(existing_lora): + """Load existing LoRA weights if provided. + + Args: + existing_lora: LoRA filename or "[None]" + + Returns: + tuple: (existing_weights dict, existing_steps int) + """ + if existing_lora == "[None]": + return {}, 0 + + lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) + # Extract steps from filename like "trained_lora_10_steps_20250225_203716" + existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) + existing_weights = {} + if lora_path: + existing_weights = comfy.utils.load_torch_file(lora_path) + return existing_weights, existing_steps + + +def _create_weight_adapter( + module, module_name, existing_weights, algorithm, lora_dtype, rank +): + """Create a weight adapter for a module with weight. + + Args: + module: The module to create adapter for + module_name: Name of the module + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (train_adapter, lora_params dict) + """ + key = f"{module_name}.weight" + shape = module.weight.shape + lora_params = {} + + if len(shape) >= 2: + alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) + dora_scale = existing_weights.get(f"{key}.dora_scale", None) + + # Try to load existing adapter + existing_adapter = None + for adapter_cls in adapters: + existing_adapter = adapter_cls.load( + module_name, existing_weights, alpha, dora_scale + ) + if existing_adapter is not None: + break + + if existing_adapter is None: + adapter_cls = adapter_maps[algorithm] + + if existing_adapter is not None: + train_adapter = existing_adapter.to_train().to(lora_dtype) + else: + # Use LoRA with alpha=1.0 by default + train_adapter = adapter_cls.create_train( + module.weight, rank=rank, alpha=1.0 + ).to(lora_dtype) + + for name, parameter in train_adapter.named_parameters(): + lora_params[f"{module_name}.{name}"] = parameter + + return train_adapter.train().requires_grad_(True), lora_params + else: + # 1D weight - use BiasDiff + diff = torch.nn.Parameter( + torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True) + ) + diff_module = BiasDiff(diff).train().requires_grad_(True) + lora_params[f"{module_name}.diff"] = diff + return diff_module, lora_params + + +def _create_bias_adapter(module, module_name, lora_dtype): + """Create a bias adapter for a module with bias. + + Args: + module: The module with bias + module_name: Name of the module + lora_dtype: dtype for LoRA weights + + Returns: + tuple: (bias_module, lora_params dict) + """ + bias = torch.nn.Parameter( + torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True) + ) + bias_module = BiasDiff(bias).train().requires_grad_(True) + lora_params = {f"{module_name}.diff_b": bias} + return bias_module, lora_params + + +def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank): + """Setup all LoRA adapters on the model. + + Args: + mp: Model patcher + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (lora_sd dict, all_weight_adapters list) + """ + lora_sd = {} + all_weight_adapters = [] + + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + adapter, params = _create_weight_adapter( + m, n, existing_weights, algorithm, lora_dtype, rank + ) + lora_sd.update(params) + key = f"{n}.weight" + mp.add_weight_wrapper(key, adapter) + all_weight_adapters.append(adapter) + + if hasattr(m, "bias") and m.bias is not None: + bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype) + lora_sd.update(bias_params) + key = f"{n}.bias" + mp.add_weight_wrapper(key, bias_adapter) + all_weight_adapters.append(bias_adapter) + + return lora_sd, all_weight_adapters + + +def _create_optimizer(optimizer_name, parameters, learning_rate): + """Create optimizer based on name. + + Args: + optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop") + parameters: Parameters to optimize + learning_rate: Learning rate + + Returns: + Optimizer instance + """ + if optimizer_name == "Adam": + return torch.optim.Adam(parameters, lr=learning_rate) + elif optimizer_name == "AdamW": + return torch.optim.AdamW(parameters, lr=learning_rate) + elif optimizer_name == "SGD": + return torch.optim.SGD(parameters, lr=learning_rate) + elif optimizer_name == "RMSprop": + return torch.optim.RMSprop(parameters, lr=learning_rate) + + +def _create_loss_function(loss_function_name): + """Create loss function based on name. + + Args: + loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1") + + Returns: + Loss function instance + """ + if loss_function_name == "MSE": + return torch.nn.MSELoss() + elif loss_function_name == "L1": + return torch.nn.L1Loss() + elif loss_function_name == "Huber": + return torch.nn.HuberLoss() + elif loss_function_name == "SmoothL1": + return torch.nn.SmoothL1Loss() + + +def _run_training_loop( + guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res +): + """Execute the training loop. + + Args: + guider: The guider object + train_sampler: The training sampler + latents: Latent tensors + num_images: Number of images + seed: Random seed + bucket_mode: Whether bucket mode is enabled + multi_res: Whether multi-resolution mode is enabled + """ + sigmas = torch.tensor(range(num_images)) + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + + if bucket_mode: + # Use first bucket's first latent as dummy for guider + dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1) + guider.sample( + noise.generate_noise({"samples": dummy_latent}), + dummy_latent, + train_sampler, + sigmas, + seed=noise.seed, + ) + elif multi_res: + # use first latent as dummy latent if multi_res + latents = latents[0].repeat(num_images, 1, 1, 1) + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) + else: + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) + + class TrainLoraNode(io.ComfyNode): @classmethod def define_schema(cls): @@ -385,6 +879,11 @@ class TrainLoraNode(io.ComfyNode): default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA.", ), + io.Boolean.Input( + "bucket_mode", + default=False, + tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.", + ), ], outputs=[ io.Model.Output( @@ -419,6 +918,7 @@ class TrainLoraNode(io.ComfyNode): algorithm, gradient_checkpointing, existing_lora, + bucket_mode, ): # Extract scalars from lists (due to is_input_list=True) model = model[0] @@ -427,215 +927,125 @@ class TrainLoraNode(io.ComfyNode): grad_accumulation_steps = grad_accumulation_steps[0] learning_rate = learning_rate[0] rank = rank[0] - optimizer = optimizer[0] - loss_function = loss_function[0] + optimizer_name = optimizer[0] + loss_function_name = loss_function[0] seed = seed[0] training_dtype = training_dtype[0] lora_dtype = lora_dtype[0] algorithm = algorithm[0] gradient_checkpointing = gradient_checkpointing[0] existing_lora = existing_lora[0] + bucket_mode = bucket_mode[0] - # Handle latents - either single dict or list of dicts - if len(latents) == 1: - latents = latents[0]["samples"] # Single latent dict + # Process latents based on mode + if bucket_mode: + latents = _process_latents_bucket_mode(latents) else: - latent_list = [] - for latent in latents: - latent = latent["samples"] - bs = latent.shape[0] - if bs != 1: - for sub_latent in latent: - latent_list.append(sub_latent[None]) - else: - latent_list.append(latent) - latents = latent_list + latents = _process_latents_standard_mode(latents) - # Handle conditioning - either single list or list of lists - if len(positive) == 1: - positive = positive[0] # Single conditioning list - else: - # Multiple conditioning lists - flatten - flat_positive = [] - for cond in positive: - if isinstance(cond, list): - flat_positive.extend(cond) - else: - flat_positive.append(cond) - positive = flat_positive + # Process conditioning + positive = _process_conditioning(positive) + # Setup model and dtype mp = model.clone() dtype = node_helpers.string_to_torch_dtype(training_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - # latents here can be list of different size latent or one large batch - if isinstance(latents, list): - all_shapes = set() - latents = [t.to(dtype) for t in latents] - for latent in latents: - all_shapes.add(latent.shape) - logging.info(f"Latent shapes: {all_shapes}") - if len(all_shapes) > 1: - multi_res = True - else: - multi_res = False - latents = torch.cat(latents, dim=0) - num_images = len(latents) - elif isinstance(latents, torch.Tensor): - latents = latents.to(dtype) - num_images = latents.shape[0] - else: - logging.error(f"Invalid latents type: {type(latents)}") + # Prepare latents and compute counts + latents, num_images, multi_res = _prepare_latents_and_count( + latents, dtype, bucket_mode + ) - logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") - if len(positive) == 1 and num_images > 1: - positive = positive * num_images - elif len(positive) != num_images: - raise ValueError( - f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." - ) + # Validate and expand conditioning + positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) with torch.inference_mode(False): - lora_sd = {} - generator = torch.Generator() - generator.manual_seed(seed) + # Setup models for training + mp.model.requires_grad_(False) # Load existing LoRA weights if provided - existing_weights = {} - existing_steps = 0 - if existing_lora != "[None]": - lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) - # Extract steps from filename like "trained_lora_10_steps_20250225_203716" - existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) - if lora_path: - existing_weights = comfy.utils.load_torch_file(lora_path) + existing_weights, existing_steps = _load_existing_lora(existing_lora) - all_weight_adapters = [] - for n, m in mp.model.named_modules(): - if hasattr(m, "weight_function"): - if m.weight is not None: - key = "{}.weight".format(n) - shape = m.weight.shape - if len(shape) >= 2: - alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) - dora_scale = existing_weights.get(f"{key}.dora_scale", None) - for adapter_cls in adapters: - existing_adapter = adapter_cls.load( - n, existing_weights, alpha, dora_scale - ) - if existing_adapter is not None: - break - else: - existing_adapter = None - adapter_cls = adapter_maps[algorithm] + # Setup LoRA adapters + lora_sd, all_weight_adapters = _setup_lora_adapters( + mp, existing_weights, algorithm, lora_dtype, rank + ) - if existing_adapter is not None: - train_adapter = existing_adapter.to_train().to( - lora_dtype - ) - else: - # Use LoRA with alpha=1.0 by default - train_adapter = adapter_cls.create_train( - m.weight, rank=rank, alpha=1.0 - ).to(lora_dtype) - for name, parameter in train_adapter.named_parameters(): - lora_sd[f"{n}.{name}"] = parameter + # Create optimizer and loss function + optimizer = _create_optimizer( + optimizer_name, lora_sd.values(), learning_rate + ) + criterion = _create_loss_function(loss_function_name) - mp.add_weight_wrapper(key, train_adapter) - all_weight_adapters.append(train_adapter) - else: - diff = torch.nn.Parameter( - torch.zeros( - m.weight.shape, dtype=lora_dtype, requires_grad=True - ) - ) - diff_module = BiasDiff(diff) - mp.add_weight_wrapper(key, BiasDiff(diff)) - all_weight_adapters.append(diff_module) - lora_sd["{}.diff".format(n)] = diff - if hasattr(m, "bias") and m.bias is not None: - key = "{}.bias".format(n) - bias = torch.nn.Parameter( - torch.zeros( - m.bias.shape, dtype=lora_dtype, requires_grad=True - ) - ) - bias_module = BiasDiff(bias) - lora_sd["{}.diff_b".format(n)] = bias - mp.add_weight_wrapper(key, BiasDiff(bias)) - all_weight_adapters.append(bias_module) - - if optimizer == "Adam": - optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate) - elif optimizer == "AdamW": - optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate) - elif optimizer == "SGD": - optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate) - elif optimizer == "RMSprop": - optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate) - - # Setup loss function based on selection - if loss_function == "MSE": - criterion = torch.nn.MSELoss() - elif loss_function == "L1": - criterion = torch.nn.L1Loss() - elif loss_function == "Huber": - criterion = torch.nn.HuberLoss() - elif loss_function == "SmoothL1": - criterion = torch.nn.SmoothL1Loss() - - # setup models + # Setup gradient checkpointing if gradient_checkpointing: for m in find_all_highest_child_module_with_forward( mp.model.diffusion_model ): patch(m) - mp.model.requires_grad_(False) + + torch.cuda.empty_cache() + # With force_full_load=False we should be able to have offloading + # But for offloading in training we need custom AutoGrad hooks for fwd/bwd comfy.model_management.load_models_gpu( [mp], memory_required=1e20, force_full_load=True ) + torch.cuda.empty_cache() - # Setup sampler and guider like in test script + # Setup loss tracking loss_map = {"loss": []} def loss_callback(loss): loss_map["loss"].append(loss) - train_sampler = TrainSampler( - criterion, - optimizer, - loss_callback=loss_callback, - batch_size=batch_size, - grad_acc=grad_accumulation_steps, - total_steps=steps * grad_accumulation_steps, - seed=seed, - training_dtype=dtype, - real_dataset=latents if multi_res else None, - ) - guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) - guider.set_conds(positive) # Set conditioning from input + # Create sampler + if bucket_mode: + train_sampler = TrainSampler( + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + grad_acc=grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, + seed=seed, + training_dtype=dtype, + bucket_latents=latents, + ) + else: + train_sampler = TrainSampler( + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + grad_acc=grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, + seed=seed, + training_dtype=dtype, + real_dataset=latents if multi_res else None, + ) - # Training loop + # Setup guider + guider = TrainGuider(mp) + guider.set_conds(positive) + + # Run training loop try: - # Generate dummy sigmas and noise - sigmas = torch.tensor(range(num_images)) - noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) - if multi_res: - # use first latent as dummy latent if multi_res - latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1))) - guider.sample( - noise.generate_noise({"samples": latents}), - latents, + _run_training_loop( + guider, train_sampler, - sigmas, - seed=noise.seed, + latents, + num_images, + seed, + bucket_mode, + multi_res, ) finally: for m in mp.model.modules(): unpatch(m) del train_sampler, optimizer + # Finalize adapters for adapter in all_weight_adapters: adapter.requires_grad_(False) @@ -645,7 +1055,7 @@ class TrainLoraNode(io.ComfyNode): return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) -class LoraModelLoader(io.ComfyNode): +class LoraModelLoader(io.ComfyNode):# @classmethod def define_schema(cls): return io.Schema( diff --git a/comfyui_version.py b/comfyui_version.py index 2f083edaf..1f28e2407 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.4.0" +__version__ = "0.6.0" diff --git a/main.py b/main.py index 0d02a087b..0e07a95da 100644 --- a/main.py +++ b/main.py @@ -23,6 +23,38 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +if os.name == "nt": + os.environ['MIMALLOC_PURGE_DELAY'] = '0' + +if __name__ == "__main__": + os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' + if args.default_device is not None: + default_dev = args.default_device + devices = list(range(32)) + devices.remove(default_dev) + devices.insert(0, default_dev) + devices = ','.join(map(str, devices)) + os.environ['CUDA_VISIBLE_DEVICES'] = str(devices) + os.environ['HIP_VISIBLE_DEVICES'] = str(devices) + + if args.cuda_device is not None: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) + logging.info("Set cuda device to: {}".format(args.cuda_device)) + + if args.oneapi_device_selector is not None: + os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector + logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector)) + + if args.deterministic: + if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + + import cuda_malloc + if "rocm" in cuda_malloc.get_torch_version_noimport(): + os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD + def handle_comfyui_manager_unavailable(): if not args.windows_standalone_build: @@ -137,40 +169,6 @@ import shutil import threading import gc - -if os.name == "nt": - os.environ['MIMALLOC_PURGE_DELAY'] = '0' - -if __name__ == "__main__": - os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' - if args.default_device is not None: - default_dev = args.default_device - devices = list(range(32)) - devices.remove(default_dev) - devices.insert(0, default_dev) - devices = ','.join(map(str, devices)) - os.environ['CUDA_VISIBLE_DEVICES'] = str(devices) - os.environ['HIP_VISIBLE_DEVICES'] = str(devices) - - if args.cuda_device is not None: - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) - os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) - os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device) - logging.info("Set cuda device to: {}".format(args.cuda_device)) - - if args.oneapi_device_selector is not None: - os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector - logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector)) - - if args.deterministic: - if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" - - import cuda_malloc - if "rocm" in cuda_malloc.get_torch_version_noimport(): - os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD - - if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") diff --git a/manager_requirements.txt b/manager_requirements.txt index 5ef0d3a1d..6585b0c19 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b5 +comfyui_manager==4.0.4 diff --git a/nodes.py b/nodes.py index 9dfe00b10..fe8a10b9c 100644 --- a/nodes.py +++ b/nodes.py @@ -343,7 +343,7 @@ class VAEEncode: CATEGORY = "latent" def encode(self, vae, pixels): - t = vae.encode(pixels[:,:,:,:3]) + t = vae.encode(pixels) return ({"samples":t}, ) class VAEEncodeTiled: @@ -361,7 +361,7 @@ class VAEEncodeTiled: CATEGORY = "_for_testing" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): - t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) + t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) return ({"samples": t}, ) class VAEEncodeForInpaint: @@ -970,7 +970,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -980,7 +980,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) diff --git a/pyproject.toml b/pyproject.toml index e4d3d616a..35a268bd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.4.0" +version = "0.6.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 9b9e61683..8b670b813 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.34.9 -comfyui-workflow-templates==0.7.59 +comfyui-frontend-package==1.35.9 +comfyui-workflow-templates==0.7.64 comfyui-embedded-docs==0.3.1 torch torchsde diff --git a/server.py b/server.py index ac4f42222..c27f8be7d 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,7 @@ import time import nodes import folder_paths import execution +from comfy_execution.jobs import JobStatus, get_job, get_all_jobs import uuid import urllib import json @@ -47,6 +48,12 @@ from middleware.cache_middleware import cache_control if args.enable_manager: import comfyui_manager + +def _remove_sensitive_from_queue(queue: list) -> list: + """Remove sensitive data (index 5) from queue item tuples.""" + return [item[:5] for item in queue] + + async def send_socket_catch_exception(function, message): try: await function(message) @@ -694,6 +701,129 @@ class PromptServer(): out[node_class] = node_info(node_class) return web.json_response(out) + @routes.get("/api/jobs") + async def get_jobs(request): + """List all jobs with filtering, sorting, and pagination. + + Query parameters: + status: Filter by status (comma-separated): pending, in_progress, completed, failed + workflow_id: Filter by workflow ID + sort_by: Sort field: created_at (default), execution_duration + sort_order: Sort direction: asc, desc (default) + limit: Max items to return (positive integer) + offset: Items to skip (non-negative integer, default 0) + """ + query = request.rel_url.query + + status_param = query.get('status') + workflow_id = query.get('workflow_id') + sort_by = query.get('sort_by', 'created_at').lower() + sort_order = query.get('sort_order', 'desc').lower() + + status_filter = None + if status_param: + status_filter = [s.strip().lower() for s in status_param.split(',') if s.strip()] + invalid_statuses = [s for s in status_filter if s not in JobStatus.ALL] + if invalid_statuses: + return web.json_response( + {"error": f"Invalid status value(s): {', '.join(invalid_statuses)}. Valid values: {', '.join(JobStatus.ALL)}"}, + status=400 + ) + + if sort_by not in {'created_at', 'execution_duration'}: + return web.json_response( + {"error": "sort_by must be 'created_at' or 'execution_duration'"}, + status=400 + ) + + if sort_order not in {'asc', 'desc'}: + return web.json_response( + {"error": "sort_order must be 'asc' or 'desc'"}, + status=400 + ) + + limit = None + + # If limit is provided, validate that it is a positive integer, else continue without a limit + if 'limit' in query: + try: + limit = int(query.get('limit')) + if limit <= 0: + return web.json_response( + {"error": "limit must be a positive integer"}, + status=400 + ) + except (ValueError, TypeError): + return web.json_response( + {"error": "limit must be an integer"}, + status=400 + ) + + offset = 0 + if 'offset' in query: + try: + offset = int(query.get('offset')) + if offset < 0: + offset = 0 + except (ValueError, TypeError): + return web.json_response( + {"error": "offset must be an integer"}, + status=400 + ) + + running, queued = self.prompt_queue.get_current_queue_volatile() + history = self.prompt_queue.get_history() + + running = _remove_sensitive_from_queue(running) + queued = _remove_sensitive_from_queue(queued) + + jobs, total = get_all_jobs( + running, queued, history, + status_filter=status_filter, + workflow_id=workflow_id, + sort_by=sort_by, + sort_order=sort_order, + limit=limit, + offset=offset + ) + + has_more = (offset + len(jobs)) < total + + return web.json_response({ + 'jobs': jobs, + 'pagination': { + 'offset': offset, + 'limit': limit, + 'total': total, + 'has_more': has_more + } + }) + + @routes.get("/api/jobs/{job_id}") + async def get_job_by_id(request): + """Get a single job by ID.""" + job_id = request.match_info.get("job_id", None) + if not job_id: + return web.json_response( + {"error": "job_id is required"}, + status=400 + ) + + running, queued = self.prompt_queue.get_current_queue_volatile() + history = self.prompt_queue.get_history(prompt_id=job_id) + + running = _remove_sensitive_from_queue(running) + queued = _remove_sensitive_from_queue(queued) + + job = get_job(job_id, running, queued, history) + if job is None: + return web.json_response( + {"error": "Job not found"}, + status=404 + ) + + return web.json_response(job) + @routes.get("/history") async def get_history(request): max_items = request.rel_url.query.get("max_items", None) @@ -717,9 +847,8 @@ class PromptServer(): async def get_queue(request): queue_info = {} current_queue = self.prompt_queue.get_current_queue_volatile() - remove_sensitive = lambda queue: [x[:5] for x in queue] - queue_info['queue_running'] = remove_sensitive(current_queue[0]) - queue_info['queue_pending'] = remove_sensitive(current_queue[1]) + queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0]) + queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1]) return web.json_response(queue_info) @routes.post("/prompt") diff --git a/tests-unit/comfy_extras_test/image_stitch_test.py b/tests-unit/comfy_extras_test/image_stitch_test.py index b5a0f022c..5c6a15ac4 100644 --- a/tests-unit/comfy_extras_test/image_stitch_test.py +++ b/tests-unit/comfy_extras_test/image_stitch_test.py @@ -25,7 +25,7 @@ class TestImageStitch: result = node.stitch(image1, "right", True, 0, "white", image2=None) - assert len(result) == 1 + assert len(result.result) == 1 assert torch.equal(result[0], image1) def test_basic_horizontal_stitch_right(self): diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index ace0d2279..f73ca7e3c 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -99,6 +99,37 @@ class ComfyClient: with urllib.request.urlopen(url) as response: return json.loads(response.read()) + def get_jobs(self, status=None, limit=None, offset=None, sort_by=None, sort_order=None): + url = "http://{}/api/jobs".format(self.server_address) + params = {} + if status is not None: + params["status"] = status + if limit is not None: + params["limit"] = limit + if offset is not None: + params["offset"] = offset + if sort_by is not None: + params["sort_by"] = sort_by + if sort_order is not None: + params["sort_order"] = sort_order + + if params: + url_values = urllib.parse.urlencode(params) + url = "{}?{}".format(url, url_values) + + with urllib.request.urlopen(url) as response: + return json.loads(response.read()) + + def get_job(self, job_id): + url = "http://{}/api/jobs/{}".format(self.server_address, job_id) + try: + with urllib.request.urlopen(url) as response: + return json.loads(response.read()) + except urllib.error.HTTPError as e: + if e.code == 404: + return None + raise + def set_test_name(self, name): self.test_name = name @@ -877,3 +908,106 @@ class TestExecution: result = client.get_all_history(max_items=5, offset=len(all_history) - 1) assert len(result) <= 1, "Should return at most 1 item when offset is near end" + + # Jobs API tests + def test_jobs_api_job_structure( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test that job objects have required fields""" + self._create_history_item(client, builder) + + jobs_response = client.get_jobs(status="completed", limit=1) + assert len(jobs_response["jobs"]) > 0, "Should have at least one job" + + job = jobs_response["jobs"][0] + assert "id" in job, "Job should have id" + assert "status" in job, "Job should have status" + assert "create_time" in job, "Job should have create_time" + assert "outputs_count" in job, "Job should have outputs_count" + assert "preview_output" in job, "Job should have preview_output" + + def test_jobs_api_preview_output_structure( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test that preview_output has correct structure""" + self._create_history_item(client, builder) + + jobs_response = client.get_jobs(status="completed", limit=1) + job = jobs_response["jobs"][0] + + if job["preview_output"] is not None: + preview = job["preview_output"] + assert "filename" in preview, "Preview should have filename" + assert "nodeId" in preview, "Preview should have nodeId" + assert "mediaType" in preview, "Preview should have mediaType" + + def test_jobs_api_pagination( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test jobs API pagination""" + for _ in range(5): + self._create_history_item(client, builder) + + first_page = client.get_jobs(limit=2, offset=0) + second_page = client.get_jobs(limit=2, offset=2) + + assert len(first_page["jobs"]) <= 2, "First page should have at most 2 jobs" + assert len(second_page["jobs"]) <= 2, "Second page should have at most 2 jobs" + + first_ids = {j["id"] for j in first_page["jobs"]} + second_ids = {j["id"] for j in second_page["jobs"]} + assert first_ids.isdisjoint(second_ids), "Pages should have different jobs" + + def test_jobs_api_sorting( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test jobs API sorting""" + for _ in range(3): + self._create_history_item(client, builder) + + desc_jobs = client.get_jobs(sort_order="desc") + asc_jobs = client.get_jobs(sort_order="asc") + + if len(desc_jobs["jobs"]) >= 2: + desc_times = [j["create_time"] for j in desc_jobs["jobs"] if j["create_time"]] + asc_times = [j["create_time"] for j in asc_jobs["jobs"] if j["create_time"]] + if len(desc_times) >= 2: + assert desc_times == sorted(desc_times, reverse=True), "Desc should be newest first" + if len(asc_times) >= 2: + assert asc_times == sorted(asc_times), "Asc should be oldest first" + + def test_jobs_api_status_filter( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test jobs API status filtering""" + self._create_history_item(client, builder) + + completed_jobs = client.get_jobs(status="completed") + assert len(completed_jobs["jobs"]) > 0, "Should have completed jobs from history" + + for job in completed_jobs["jobs"]: + assert job["status"] == "completed", "Should only return completed jobs" + + # Pending jobs are transient - just verify filter doesn't error + pending_jobs = client.get_jobs(status="pending") + for job in pending_jobs["jobs"]: + assert job["status"] == "pending", "Should only return pending jobs" + + def test_get_job_by_id( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test getting a single job by ID""" + result = self._create_history_item(client, builder) + prompt_id = result.get_prompt_id() + + job = client.get_job(prompt_id) + assert job is not None, "Should find the job" + assert job["id"] == prompt_id, "Job ID should match" + assert "outputs" in job, "Single job should include outputs" + + def test_get_job_not_found( + self, client: ComfyClient, builder: GraphBuilder + ): + """Test getting a non-existent job returns 404""" + job = client.get_job("nonexistent-job-id") + assert job is None, "Non-existent job should return None" diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py new file mode 100644 index 000000000..918c8080a --- /dev/null +++ b/tests/execution/test_jobs.py @@ -0,0 +1,361 @@ +"""Unit tests for comfy_execution/jobs.py""" + +from comfy_execution.jobs import ( + JobStatus, + is_previewable, + normalize_queue_item, + normalize_history_item, + get_outputs_summary, + apply_sorting, +) + + +class TestJobStatus: + """Test JobStatus constants.""" + + def test_status_values(self): + """Status constants should have expected string values.""" + assert JobStatus.PENDING == 'pending' + assert JobStatus.IN_PROGRESS == 'in_progress' + assert JobStatus.COMPLETED == 'completed' + assert JobStatus.FAILED == 'failed' + + def test_all_contains_all_statuses(self): + """ALL should contain all status values.""" + assert JobStatus.PENDING in JobStatus.ALL + assert JobStatus.IN_PROGRESS in JobStatus.ALL + assert JobStatus.COMPLETED in JobStatus.ALL + assert JobStatus.FAILED in JobStatus.ALL + assert len(JobStatus.ALL) == 4 + + +class TestIsPreviewable: + """Unit tests for is_previewable()""" + + def test_previewable_media_types(self): + """Images, video, audio media types should be previewable.""" + for media_type in ['images', 'video', 'audio']: + assert is_previewable(media_type, {}) is True + + def test_non_previewable_media_types(self): + """Other media types should not be previewable.""" + for media_type in ['latents', 'text', 'metadata', 'files']: + assert is_previewable(media_type, {}) is False + + def test_3d_extensions_previewable(self): + """3D file extensions should be previewable regardless of media_type.""" + for ext in ['.obj', '.fbx', '.gltf', '.glb']: + item = {'filename': f'model{ext}'} + assert is_previewable('files', item) is True + + def test_3d_extensions_case_insensitive(self): + """3D extension check should be case insensitive.""" + item = {'filename': 'MODEL.GLB'} + assert is_previewable('files', item) is True + + def test_video_format_previewable(self): + """Items with video/ format should be previewable.""" + item = {'format': 'video/mp4'} + assert is_previewable('files', item) is True + + def test_audio_format_previewable(self): + """Items with audio/ format should be previewable.""" + item = {'format': 'audio/wav'} + assert is_previewable('files', item) is True + + def test_other_format_not_previewable(self): + """Items with other format should not be previewable.""" + item = {'format': 'application/json'} + assert is_previewable('files', item) is False + + +class TestGetOutputsSummary: + """Unit tests for get_outputs_summary()""" + + def test_empty_outputs(self): + """Empty outputs should return 0 count and None preview.""" + count, preview = get_outputs_summary({}) + assert count == 0 + assert preview is None + + def test_counts_across_multiple_nodes(self): + """Outputs from multiple nodes should all be counted.""" + outputs = { + 'node1': {'images': [{'filename': 'a.png', 'type': 'output'}]}, + 'node2': {'images': [{'filename': 'b.png', 'type': 'output'}]}, + 'node3': {'images': [ + {'filename': 'c.png', 'type': 'output'}, + {'filename': 'd.png', 'type': 'output'} + ]} + } + count, preview = get_outputs_summary(outputs) + assert count == 4 + + def test_skips_animated_key_and_non_list_values(self): + """The 'animated' key and non-list values should be skipped.""" + outputs = { + 'node1': { + 'images': [{'filename': 'test.png', 'type': 'output'}], + 'animated': [True], # Should skip due to key name + 'metadata': 'string', # Should skip due to non-list + 'count': 42 # Should skip due to non-list + } + } + count, preview = get_outputs_summary(outputs) + assert count == 1 + + def test_preview_prefers_type_output(self): + """Items with type='output' should be preferred for preview.""" + outputs = { + 'node1': { + 'images': [ + {'filename': 'temp.png', 'type': 'temp'}, + {'filename': 'output.png', 'type': 'output'} + ] + } + } + count, preview = get_outputs_summary(outputs) + assert count == 2 + assert preview['filename'] == 'output.png' + + def test_preview_fallback_when_no_output_type(self): + """If no type='output', should use first previewable.""" + outputs = { + 'node1': { + 'images': [ + {'filename': 'temp1.png', 'type': 'temp'}, + {'filename': 'temp2.png', 'type': 'temp'} + ] + } + } + count, preview = get_outputs_summary(outputs) + assert preview['filename'] == 'temp1.png' + + def test_non_previewable_media_types_counted_but_no_preview(self): + """Non-previewable media types should be counted but not used as preview.""" + outputs = { + 'node1': { + 'latents': [ + {'filename': 'latent1.safetensors'}, + {'filename': 'latent2.safetensors'} + ] + } + } + count, preview = get_outputs_summary(outputs) + assert count == 2 + assert preview is None + + def test_previewable_media_types(self): + """Images, video, and audio media types should be previewable.""" + for media_type in ['images', 'video', 'audio']: + outputs = { + 'node1': { + media_type: [{'filename': 'test.file', 'type': 'output'}] + } + } + count, preview = get_outputs_summary(outputs) + assert preview is not None, f"{media_type} should be previewable" + + def test_3d_files_previewable(self): + """3D file extensions should be previewable.""" + for ext in ['.obj', '.fbx', '.gltf', '.glb']: + outputs = { + 'node1': { + 'files': [{'filename': f'model{ext}', 'type': 'output'}] + } + } + count, preview = get_outputs_summary(outputs) + assert preview is not None, f"3D file {ext} should be previewable" + + def test_format_mime_type_previewable(self): + """Files with video/ or audio/ format should be previewable.""" + for fmt in ['video/x-custom', 'audio/x-custom']: + outputs = { + 'node1': { + 'files': [{'filename': 'file.custom', 'format': fmt, 'type': 'output'}] + } + } + count, preview = get_outputs_summary(outputs) + assert preview is not None, f"Format {fmt} should be previewable" + + def test_preview_enriched_with_node_metadata(self): + """Preview should include nodeId, mediaType, and original fields.""" + outputs = { + 'node123': { + 'images': [{'filename': 'test.png', 'type': 'output', 'subfolder': 'outputs'}] + } + } + count, preview = get_outputs_summary(outputs) + assert preview['nodeId'] == 'node123' + assert preview['mediaType'] == 'images' + assert preview['subfolder'] == 'outputs' + + +class TestApplySorting: + """Unit tests for apply_sorting()""" + + def test_sort_by_create_time_desc(self): + """Default sort by create_time descending.""" + jobs = [ + {'id': 'a', 'create_time': 100}, + {'id': 'b', 'create_time': 300}, + {'id': 'c', 'create_time': 200}, + ] + result = apply_sorting(jobs, 'created_at', 'desc') + assert [j['id'] for j in result] == ['b', 'c', 'a'] + + def test_sort_by_create_time_asc(self): + """Sort by create_time ascending.""" + jobs = [ + {'id': 'a', 'create_time': 100}, + {'id': 'b', 'create_time': 300}, + {'id': 'c', 'create_time': 200}, + ] + result = apply_sorting(jobs, 'created_at', 'asc') + assert [j['id'] for j in result] == ['a', 'c', 'b'] + + def test_sort_by_execution_duration(self): + """Sort by execution_duration should order by duration.""" + jobs = [ + {'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100}, # 5s + {'id': 'b', 'create_time': 300, 'execution_start_time': 300, 'execution_end_time': 1300}, # 1s + {'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200}, # 3s + ] + result = apply_sorting(jobs, 'execution_duration', 'desc') + assert [j['id'] for j in result] == ['a', 'c', 'b'] + + def test_sort_with_none_values(self): + """Jobs with None values should sort as 0.""" + jobs = [ + {'id': 'a', 'create_time': 100, 'execution_start_time': 100, 'execution_end_time': 5100}, + {'id': 'b', 'create_time': 300, 'execution_start_time': None, 'execution_end_time': None}, + {'id': 'c', 'create_time': 200, 'execution_start_time': 200, 'execution_end_time': 3200}, + ] + result = apply_sorting(jobs, 'execution_duration', 'asc') + assert result[0]['id'] == 'b' # None treated as 0, comes first + + +class TestNormalizeQueueItem: + """Unit tests for normalize_queue_item()""" + + def test_basic_normalization(self): + """Queue item should be normalized to job dict.""" + item = ( + 10, # priority/number + 'prompt-123', # prompt_id + {'nodes': {}}, # prompt + { + 'create_time': 1234567890, + 'extra_pnginfo': {'workflow': {'id': 'workflow-abc'}} + }, # extra_data + ['node1'], # outputs_to_execute + ) + job = normalize_queue_item(item, JobStatus.PENDING) + + assert job['id'] == 'prompt-123' + assert job['status'] == 'pending' + assert job['priority'] == 10 + assert job['create_time'] == 1234567890 + assert 'execution_start_time' not in job + assert 'execution_end_time' not in job + assert 'execution_error' not in job + assert 'preview_output' not in job + assert job['outputs_count'] == 0 + assert job['workflow_id'] == 'workflow-abc' + + +class TestNormalizeHistoryItem: + """Unit tests for normalize_history_item()""" + + def test_completed_job(self): + """Completed history item should have correct status and times from messages.""" + history_item = { + 'prompt': ( + 5, # priority + 'prompt-456', + {'nodes': {}}, + { + 'create_time': 1234567890000, + 'extra_pnginfo': {'workflow': {'id': 'workflow-xyz'}} + }, + ['node1'], + ), + 'status': { + 'status_str': 'success', + 'completed': True, + 'messages': [ + ('execution_start', {'prompt_id': 'prompt-456', 'timestamp': 1234567890500}), + ('execution_success', {'prompt_id': 'prompt-456', 'timestamp': 1234567893000}), + ] + }, + 'outputs': {}, + } + job = normalize_history_item('prompt-456', history_item) + + assert job['id'] == 'prompt-456' + assert job['status'] == 'completed' + assert job['priority'] == 5 + assert job['execution_start_time'] == 1234567890500 + assert job['execution_end_time'] == 1234567893000 + assert job['workflow_id'] == 'workflow-xyz' + + def test_failed_job(self): + """Failed history item should have failed status and error from messages.""" + history_item = { + 'prompt': ( + 5, + 'prompt-789', + {'nodes': {}}, + {'create_time': 1234567890000}, + ['node1'], + ), + 'status': { + 'status_str': 'error', + 'completed': False, + 'messages': [ + ('execution_start', {'prompt_id': 'prompt-789', 'timestamp': 1234567890500}), + ('execution_error', { + 'prompt_id': 'prompt-789', + 'node_id': '5', + 'node_type': 'KSampler', + 'exception_message': 'CUDA out of memory', + 'exception_type': 'RuntimeError', + 'traceback': ['Traceback...', 'RuntimeError: CUDA out of memory'], + 'timestamp': 1234567891000, + }) + ] + }, + 'outputs': {}, + } + + job = normalize_history_item('prompt-789', history_item) + assert job['status'] == 'failed' + assert job['execution_start_time'] == 1234567890500 + assert job['execution_end_time'] == 1234567891000 + assert job['execution_error']['node_id'] == '5' + assert job['execution_error']['node_type'] == 'KSampler' + assert job['execution_error']['exception_message'] == 'CUDA out of memory' + + def test_include_outputs(self): + """When include_outputs=True, should include full output data.""" + history_item = { + 'prompt': ( + 5, + 'prompt-123', + {'nodes': {'1': {}}}, + {'create_time': 1234567890, 'client_id': 'abc'}, + ['node1'], + ), + 'status': {'status_str': 'success', 'completed': True, 'messages': []}, + 'outputs': {'node1': {'images': [{'filename': 'test.png'}]}}, + } + job = normalize_history_item('prompt-123', history_item, include_outputs=True) + + assert 'outputs' in job + assert 'workflow' in job + assert 'execution_status' in job + assert job['outputs'] == {'node1': {'images': [{'filename': 'test.png'}]}} + assert job['workflow'] == { + 'prompt': {'nodes': {'1': {}}}, + 'extra_data': {'create_time': 1234567890, 'client_id': 'abc'}, + }