From eb8737d675022e730364294c395111af3545d523 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Wed, 25 Feb 2026 18:30:48 -0800 Subject: [PATCH 01/17] Update requirements.txt (#12642) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index fed5df5fd..88a056e5c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.39.16 +comfyui-frontend-package==1.39.19 comfyui-workflow-templates==0.9.3 comfyui-embedded-docs==0.4.3 torch From e14b04478c1712ec8417a046832b821af263ea13 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 25 Feb 2026 19:36:02 -0800 Subject: [PATCH 02/17] Fix LTXAV text enc min length. (#12640) Should have been 1024 instead of 512 --- comfy/text_encoders/lt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 64ce64f89..c547c5ee5 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -72,7 +72,7 @@ class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) special_tokens = {"": 262144, "": 106} - super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): From 72535316701ea8074b99755194f149a26e88b4c8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 25 Feb 2026 20:13:47 -0800 Subject: [PATCH 03/17] Fix ltxav te mem estimation. (#12643) --- comfy/text_encoders/lt.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index c547c5ee5..e86ea9f4e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -6,6 +6,7 @@ import comfy.text_encoders.genmo import torch import comfy.utils import math +import itertools class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -199,8 +200,10 @@ class LTXAVTEModel(torch.nn.Module): constant /= 2.0 token_weight_pairs = token_weight_pairs.get("gemma3_12b", []) - num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - num_tokens = max(num_tokens, 64) + m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs]) + + num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m + num_tokens = max(num_tokens, 642) return num_tokens * constant * 1024 * 1024 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): From 907e5dcbbffab5e7011346af280a428dc40f3136 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 26 Feb 2026 06:38:46 +0200 Subject: [PATCH 04/17] initial FlowRVS support (#12637) --- comfy/ldm/wan/vae.py | 3 ++- comfy/model_base.py | 9 +++++++++ comfy/model_detection.py | 3 +++ comfy/model_sampling.py | 10 ++++++++++ comfy/sd.py | 3 ++- comfy/supported_models.py | 12 +++++++++++- comfy_extras/nodes_model_advanced.py | 4 +++- 7 files changed, 40 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index fd125ceed..7903c7690 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -459,6 +459,7 @@ class WanVAE(nn.Module): attn_scales=[], temperal_downsample=[True, True, False], image_channels=3, + conv_out_channels=3, dropout=0.0): super().__init__() self.dim = dim @@ -474,7 +475,7 @@ class WanVAE(nn.Module): attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks, + self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def encode(self, x): diff --git a/comfy/model_base.py b/comfy/model_base.py index 2f49578f6..4e2096d4b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -76,6 +76,7 @@ class ModelType(Enum): FLUX = 8 IMG_TO_IMG = 9 FLOW_COSMOS = 10 + IMG_TO_IMG_FLOW = 11 def model_sampling(model_config, model_type): @@ -108,6 +109,8 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLOW_COSMOS: c = comfy.model_sampling.COSMOS_RFLOW s = comfy.model_sampling.ModelSamplingCosmosRFlow + elif model_type == ModelType.IMG_TO_IMG_FLOW: + c = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSampling(s, c): pass @@ -1466,6 +1469,12 @@ class WAN22(WAN21): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class WAN21_FlowRVS(WAN21): + def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None): + model_config.unet_config["model_type"] = "t2v" + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) + self.image_to_video = image_to_video + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 30ea03e8e..030ae6980 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -509,6 +509,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if ref_conv_weight is not None: dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + if metadata is not None and "config" in metadata: + dit_config.update(json.loads(metadata["config"]).get("transformer", {})) + return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 2a00ed819..13860e6a2 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -83,6 +83,16 @@ class IMG_TO_IMG(X0): def calculate_input(self, sigma, noise): return noise +class IMG_TO_IMG_FLOW(CONST): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + return latent_image + + def inverse_noise_scaling(self, sigma, latent): + return 1.0 - latent + class COSMOS_RFLOW: def calculate_input(self, sigma, noise): sigma = (sigma / (sigma + 1)) diff --git a/comfy/sd.py b/comfy/sd.py index 69d4531e3..de119eb8e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -694,8 +694,9 @@ class VAE: self.latent_dim = 3 self.latent_channels = 16 self.output_channels = sd["encoder.conv1.weight"].shape[1] + self.conv_out_channels = sd["decoder.head.2.weight"].shape[0] self.pad_channel_value = 1.0 - ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c28be1716..6d08ff0a5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1256,6 +1256,16 @@ class WAN22_T2V(WAN21_T2V): out = model_base.WAN22(self, image_to_video=True, device=device) return out +class WAN21_FlowRVS(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "flow_rvs", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1667,6 +1677,6 @@ class ACEStep15(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 9601a5f76..8bf6a1afa 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -52,7 +52,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],), + "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],), "zsnr": ("BOOLEAN", {"default": False, "advanced": True}), }} @@ -76,6 +76,8 @@ class ModelSamplingDiscrete: sampling_type = comfy.model_sampling.X0 elif sampling == "img_to_img": sampling_type = comfy.model_sampling.IMG_TO_IMG + elif sampling == "img_to_img_flow": + sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSamplingAdvanced(sampling_base, sampling_type): pass From a4522017c518d1f0c3c5d2a803a2d31265da5cd4 Mon Sep 17 00:00:00 2001 From: Tavi Halperin Date: Thu, 26 Feb 2026 08:25:23 +0200 Subject: [PATCH 05/17] feat: per-guide attention strength control in self-attention (#12518) Implements per-guide attention attenuation via log-space additive bias in self-attention. Each guide reference tracks its own strength and optional spatial mask in conditioning metadata (guide_attention_entries). --- comfy/ldm/lightricks/av_model.py | 9 +- comfy/ldm/lightricks/model.py | 264 ++++++++++++++++++++++++++++++- comfy/model_base.py | 44 ++++++ comfy_extras/nodes_lt.py | 47 +++++- 4 files changed, 352 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 2b080aaeb..553fd5b38 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -218,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module): def forward( self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, - v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, + v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -234,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module): vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2))) norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa del vshift_msa, vscale_msa - attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) + attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options) del norm_vx # video cross-attention vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] @@ -726,7 +726,7 @@ class LTXAVModel(LTXVModel): return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] def _process_transformer_blocks( - self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs + self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs ): vx = x[0] ax = x[1] @@ -770,6 +770,7 @@ class LTXAVModel(LTXVModel): v_cross_gate_timestep=args["v_cross_gate_timestep"], a_cross_gate_timestep=args["a_cross_gate_timestep"], transformer_options=args["transformer_options"], + self_attention_mask=args.get("self_attention_mask"), ) return out @@ -790,6 +791,7 @@ class LTXAVModel(LTXVModel): "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "transformer_options": transformer_options, + "self_attention_mask": self_attention_mask, }, {"original_block": block_wrap}, ) @@ -811,6 +813,7 @@ class LTXAVModel(LTXVModel): v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, transformer_options=transformer_options, + self_attention_mask=self_attention_mask, ) return [vx, ax] diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index d61e19d6e..60d760d29 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from enum import Enum import functools +import logging import math from typing import Dict, Optional, Tuple @@ -14,6 +15,8 @@ import comfy.ldm.common_dit from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords +logger = logging.getLogger(__name__) + def _log_base(x, base): return np.log(x) / np.log(base) @@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module): self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) attn1_input = comfy.ldm.common_dit.rms_norm(x) attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) x.addcmul_(attn1_input, gate_msa) del attn1_input @@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC): """Process input data. Must be implemented by subclasses.""" pass + def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): + """Build self-attention mask for per-guide attention attenuation. + + Base implementation returns None (no attenuation). Subclasses that + support guide-based attention control should override this. + """ + return None + @abstractmethod - def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs): """Process transformer blocks. Must be implemented by subclasses.""" pass @@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC): attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) + # Build self-attention mask for per-guide attenuation + self_attention_mask = self._build_guide_self_attention_mask( + x, transformer_options, merged_args + ) + # Process transformer blocks x = self._process_transformer_blocks( - x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args + x, context, attention_mask, timestep, pe, + transformer_options=transformer_options, + self_attention_mask=self_attention_mask, + **merged_args, ) # Process output @@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel): pixel_coords = pixel_coords[:, :, grid_mask, ...] kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] + + # Compute per-guide surviving token counts from guide_attention_entries. + # Each entry tracks one guide reference; they are appended in order and + # their pre_filter_counts partition the kf_grid_mask. + guide_entries = kwargs.get("guide_attention_entries", None) + if guide_entries: + total_pfc = sum(e["pre_filter_count"] for e in guide_entries) + if total_pfc != len(kf_grid_mask): + raise ValueError( + f"guide pre_filter_counts ({total_pfc}) != " + f"keyframe grid mask length ({len(kf_grid_mask)})" + ) + resolved_entries = [] + offset = 0 + for entry in guide_entries: + pfc = entry["pre_filter_count"] + entry_mask = kf_grid_mask[offset:offset + pfc] + surviving = int(entry_mask.sum().item()) + resolved_entries.append({ + **entry, + "surviving_count": surviving, + }) + offset += pfc + additional_args["resolved_guide_entries"] = resolved_entries + keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs + # Total surviving guide tokens (all guides) + additional_args["num_guide_tokens"] = keyframe_idxs.shape[2] + x = self.patchify_proj(x) return x, pixel_coords, additional_args - def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): + def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): + """Build self-attention mask for per-guide attention attenuation. + + Reads resolved_guide_entries from merged_args (computed in _process_input) + to build a log-space additive bias mask that attenuates noisy ↔ guide + attention for each guide reference independently. + + Returns None if no attenuation is needed (all strengths == 1.0 and no + spatial masks, or no guide tokens). + """ + if isinstance(x, list): + # AV model: x = [vx, ax]; use vx for token count and device + total_tokens = x[0].shape[1] + device = x[0].device + dtype = x[0].dtype + else: + total_tokens = x.shape[1] + device = x.device + dtype = x.dtype + + num_guide_tokens = merged_args.get("num_guide_tokens", 0) + if num_guide_tokens == 0: + return None + + resolved_entries = merged_args.get("resolved_guide_entries", None) + if not resolved_entries: + return None + + # Check if any attenuation is actually needed + needs_attenuation = any( + e["strength"] < 1.0 or e.get("pixel_mask") is not None + for e in resolved_entries + ) + if not needs_attenuation: + return None + + # Build per-guide-token weights for all tracked guide tokens. + # Guides are appended in order at the end of the sequence. + guide_start = total_tokens - num_guide_tokens + all_weights = [] + total_tracked = 0 + + for entry in resolved_entries: + surviving = entry["surviving_count"] + if surviving == 0: + continue + + strength = entry["strength"] + pixel_mask = entry.get("pixel_mask") + latent_shape = entry.get("latent_shape") + + if pixel_mask is not None and latent_shape is not None: + f_lat, h_lat, w_lat = latent_shape + per_token = self._downsample_mask_to_latent( + pixel_mask.to(device=device, dtype=dtype), + f_lat, h_lat, w_lat, + ) + # per_token shape: (B, f_lat*h_lat*w_lat). + # Collapse batch dim — the mask is assumed identical across the + # batch; validate and take the first element to get (1, tokens). + if per_token.shape[0] > 1: + ref = per_token[0] + for bi in range(1, per_token.shape[0]): + if not torch.equal(ref, per_token[bi]): + logger.warning( + "pixel_mask differs across batch elements; " + "using first element only." + ) + break + per_token = per_token[:1] + # `surviving` is the post-grid_mask token count. + # Clamp to surviving to handle any mismatch safely. + n_weights = min(per_token.shape[1], surviving) + weights = per_token[:, :n_weights] * strength # (1, n_weights) + else: + weights = torch.full( + (1, surviving), strength, device=device, dtype=dtype + ) + + all_weights.append(weights) + total_tracked += weights.shape[1] + + if not all_weights: + return None + + # Concatenate per-token weights for all tracked guides + tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked) + + # Check if any weight is actually < 1.0 (otherwise no attenuation needed) + if (tracked_weights >= 1.0).all(): + return None + + # Build the mask: guide tokens are at the end of the sequence. + # Tracked guides come first (in order), untracked follow. + return self._build_self_attention_mask( + total_tokens, num_guide_tokens, total_tracked, + tracked_weights, guide_start, device, dtype, + ) + + @staticmethod + def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat): + """Downsample a pixel-space mask to per-token latent weights. + + Args: + mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1]. + f_lat: Number of latent frames (pre-dilation original count). + h_lat: Latent height (pre-dilation original height). + w_lat: Latent width (pre-dilation original width). + + Returns: + (B, F_lat * H_lat * W_lat) flattened per-token weights. + """ + b = mask.shape[0] + f_pix = mask.shape[2] + + # Spatial downsampling: area interpolation per frame + spatial_down = torch.nn.functional.interpolate( + rearrange(mask, "b 1 f h w -> (b f) 1 h w"), + size=(h_lat, w_lat), + mode="area", + ) + spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b) + + # Temporal downsampling: first pixel frame maps to first latent frame, + # remaining pixel frames are averaged in groups for causal temporal structure. + first_frame = spatial_down[:, :, :1, :, :] + if f_pix > 1 and f_lat > 1: + remaining_pix = f_pix - 1 + remaining_lat = f_lat - 1 + t = remaining_pix // remaining_lat + if t < 1: + # Fewer pixel frames than latent frames — upsample by repeating + # the available pixel frames via nearest interpolation. + rest_flat = rearrange( + spatial_down[:, :, 1:, :, :], + "b 1 f h w -> (b h w) 1 f", + ) + rest_up = torch.nn.functional.interpolate( + rest_flat, size=remaining_lat, mode="nearest", + ) + rest = rearrange( + rest_up, "(b h w) 1 f -> b 1 f h w", + b=b, h=h_lat, w=w_lat, + ) + else: + # Trim trailing pixel frames that don't fill a complete group + usable = remaining_lat * t + rest = rearrange( + spatial_down[:, :, 1:1 + usable, :, :], + "b 1 (f t) h w -> b 1 f t h w", + t=t, + ) + rest = rest.mean(dim=3) + latent_mask = torch.cat([first_frame, rest], dim=2) + elif f_lat > 1: + # Single pixel frame but multiple latent frames — repeat the + # single frame across all latent frames. + latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1) + else: + latent_mask = first_frame + + return rearrange(latent_mask, "b 1 f h w -> b (f h w)") + + @staticmethod + def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count, + tracked_weights, guide_start, device, dtype): + """Build a log-space additive self-attention bias mask. + + Attenuates attention between noisy tokens and tracked guide tokens. + Untracked guide tokens (at the end of the guide portion) keep full attention. + + Args: + total_tokens: Total sequence length. + num_guide_tokens: Total guide tokens (all guides) at end of sequence. + tracked_count: Number of tracked guide tokens (first in the guide portion). + tracked_weights: (1, tracked_count) tensor, values in [0, 1]. + guide_start: Index where guide tokens begin in the sequence. + device: Target device. + dtype: Target dtype. + + Returns: + (1, 1, total_tokens, total_tokens) additive bias mask. + 0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked. + """ + finfo = torch.finfo(dtype) + mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype) + tracked_end = guide_start + tracked_count + + # Convert weights to log-space bias + w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count) + log_w = torch.full_like(w, finfo.min) + positive_mask = w > 0 + if positive_mask.any(): + log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny)) + + # noisy → tracked guides: each noisy row gets the same per-guide weight + mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1) + # tracked guides → noisy: each guide row broadcasts its weight across noisy cols + mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1) + + return mask + + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel): def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel): timestep=timestep, pe=pe, transformer_options=transformer_options, + self_attention_mask=self_attention_mask, ) return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 4e2096d4b..04695c079 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -65,6 +65,42 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher + +class _CONDGuideEntries(comfy.conds.CONDConstant): + """CONDConstant subclass that safely compares guide_attention_entries. + + guide_attention_entries may contain ``pixel_mask`` tensors. The default + ``CONDConstant.can_concat`` uses ``!=`` which triggers a ``ValueError`` + on tensors. This subclass performs a structural comparison instead. + """ + + def can_concat(self, other): + if not isinstance(other, _CONDGuideEntries): + return False + a, b = self.cond, other.cond + if len(a) != len(b): + return False + for ea, eb in zip(a, b): + if ea["pre_filter_count"] != eb["pre_filter_count"]: + return False + if ea["strength"] != eb["strength"]: + return False + if ea.get("latent_shape") != eb.get("latent_shape"): + return False + a_has = ea.get("pixel_mask") is not None + b_has = eb.get("pixel_mask") is not None + if a_has != b_has: + return False + if a_has: + pm_a, pm_b = ea["pixel_mask"], eb["pixel_mask"] + if pm_a is not pm_b: + if (pm_a.shape != pm_b.shape + or pm_a.device != pm_b.device + or pm_a.dtype != pm_b.dtype + or not torch.equal(pm_a, pm_b)): + return False + return True + class ModelType(Enum): EPS = 1 V_PREDICTION = 2 @@ -974,6 +1010,10 @@ class LTXV(BaseModel): if keyframe_idxs is not None: out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) + return out def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): @@ -1026,6 +1066,10 @@ class LTXAV(BaseModel): if latent_shapes is not None: out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) + return out def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 1eeeec011..32fe921ff 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode): generate = execute # TODO: remove +def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0): + """Append a guide_attention_entry to both positive and negative conditioning. + + Each entry tracks one guide reference for per-reference attention control. + Entries are derived independently from each conditioning to avoid cross-contamination. + """ + new_entry = { + "pre_filter_count": pre_filter_count, + "strength": strength, + "pixel_mask": None, + "latent_shape": latent_shape, + } + results = [] + for cond in (positive, negative): + # Read existing entries from this specific conditioning + existing = [] + for t in cond: + found = t[1].get("guide_attention_entries", None) + if found is not None: + existing = found + break + # Shallow copy and append (no deepcopy needed — entries contain + # only scalars and None for pixel_mask at this call site). + entries = [*existing, new_entry] + results.append(node_helpers.conditioning_set_values( + cond, {"guide_attention_entries": entries} + )) + return results[0], results[1] + + def conditioning_get_any_value(conditioning, key, default=None): for t in conditioning: if key in t[1]: @@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode): scale_factors, ) + # Track this guide for per-reference attention control. + pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4] + guide_latent_shape = list(t.shape[2:]) # [F, H, W] + positive, negative = _append_guide_attention_entry( + positive, negative, pre_filter_count, guide_latent_shape, strength=strength, + ) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) generate = execute # TODO: remove @@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode): latent_image = latent_image[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes] - positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) - negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) + positive = node_helpers.conditioning_set_values(positive, { + "keyframe_idxs": None, + "guide_attention_entries": None, + }) + negative = node_helpers.conditioning_set_values(negative, { + "keyframe_idxs": None, + "guide_attention_entries": None, + }) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) From 8a4d85c708435b47d0570637fdf1e89199702c48 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 25 Feb 2026 22:30:31 -0800 Subject: [PATCH 06/17] Cleanups to the last PR. (#12646) --- comfy/conds.py | 21 ++++++++++++++++++++- comfy/model_base.py | 40 ++-------------------------------------- 2 files changed, 22 insertions(+), 39 deletions(-) diff --git a/comfy/conds.py b/comfy/conds.py index 5af3e93ea..55d8cdd78 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -4,6 +4,25 @@ import comfy.utils import logging +def is_equal(x, y): + if torch.is_tensor(x) and torch.is_tensor(y): + return torch.equal(x, y) + elif isinstance(x, dict) and isinstance(y, dict): + if x.keys() != y.keys(): + return False + return all(is_equal(x[k], y[k]) for k in x) + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + if type(x) is not type(y) or len(x) != len(y): + return False + return all(is_equal(a, b) for a, b in zip(x, y)) + else: + try: + return x == y + except Exception: + logging.warning("comparison issue with COND") + return False + + class CONDRegular: def __init__(self, cond): self.cond = cond @@ -84,7 +103,7 @@ class CONDConstant(CONDRegular): return self._copy_with(self.cond) def can_concat(self, other): - if self.cond != other.cond: + if not is_equal(self.cond, other.cond): return False return True diff --git a/comfy/model_base.py b/comfy/model_base.py index 04695c079..8f852e3c6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -65,42 +65,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher - -class _CONDGuideEntries(comfy.conds.CONDConstant): - """CONDConstant subclass that safely compares guide_attention_entries. - - guide_attention_entries may contain ``pixel_mask`` tensors. The default - ``CONDConstant.can_concat`` uses ``!=`` which triggers a ``ValueError`` - on tensors. This subclass performs a structural comparison instead. - """ - - def can_concat(self, other): - if not isinstance(other, _CONDGuideEntries): - return False - a, b = self.cond, other.cond - if len(a) != len(b): - return False - for ea, eb in zip(a, b): - if ea["pre_filter_count"] != eb["pre_filter_count"]: - return False - if ea["strength"] != eb["strength"]: - return False - if ea.get("latent_shape") != eb.get("latent_shape"): - return False - a_has = ea.get("pixel_mask") is not None - b_has = eb.get("pixel_mask") is not None - if a_has != b_has: - return False - if a_has: - pm_a, pm_b = ea["pixel_mask"], eb["pixel_mask"] - if pm_a is not pm_b: - if (pm_a.shape != pm_b.shape - or pm_a.device != pm_b.device - or pm_a.dtype != pm_b.dtype - or not torch.equal(pm_a, pm_b)): - return False - return True - class ModelType(Enum): EPS = 1 V_PREDICTION = 2 @@ -1012,7 +976,7 @@ class LTXV(BaseModel): guide_attention_entries = kwargs.get("guide_attention_entries", None) if guide_attention_entries is not None: - out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) return out @@ -1068,7 +1032,7 @@ class LTXAV(BaseModel): guide_attention_entries = kwargs.get("guide_attention_entries", None) if guide_attention_entries is not None: - out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries) + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) return out From 74b5a337dcc4d6b276e4af45aa8a654c82569072 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Thu, 26 Feb 2026 01:00:32 -0800 Subject: [PATCH 07/17] fix: move essentials_category to correct replacement nodes (#12568) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move essentials_category from deprecated/incorrect nodes to their replacements: - ImageBatch → BatchImagesNode (ImageBatch is deprecated) - Blur → removed (should use subgraph blueprint) - GetVideoComponents → Video Slice Amp-Thread-ID: https://ampcode.com/threads/T-019c8340-4da2-723b-a09f-83895c5bbda5 --- comfy_extras/nodes_post_processing.py | 2 +- comfy_extras/nodes_video.py | 2 +- nodes.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index c67245e7a..4a0f7141a 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -79,7 +79,6 @@ class Blur(io.ComfyNode): node_id="ImageBlur", display_name="Image Blur", category="image/postprocessing", - essentials_category="Image Tools", inputs=[ io.Image.Input("image"), io.Int.Input("blur_radius", default=1, min=1, max=31, step=1), @@ -568,6 +567,7 @@ class BatchImagesNode(io.ComfyNode): node_id="BatchImagesNode", display_name="Batch Images", category="image", + essentials_category="Image Tools", search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ io.Autogrow.Input("images", template=autogrow_template) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index db7c171a2..5c096c232 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -147,7 +147,6 @@ class GetVideoComponents(io.ComfyNode): search_aliases=["extract frames", "split video", "video to images", "demux"], display_name="Get Video Components", category="image/video", - essentials_category="Video Tools", description="Extracts all components from a video: frames, audio, and framerate.", inputs=[ io.Video.Input("video", tooltip="The video to extract components from."), @@ -218,6 +217,7 @@ class VideoSlice(io.ComfyNode): "start time", ], category="image/video", + essentials_category="Video Tools", inputs=[ io.Video.Input("video"), io.Float.Input( diff --git a/nodes.py b/nodes.py index e2fc20d53..bff073e30 100644 --- a/nodes.py +++ b/nodes.py @@ -1925,7 +1925,6 @@ class ImageInvert: class ImageBatch: SEARCH_ALIASES = ["combine images", "merge images", "stack images"] - ESSENTIALS_CATEGORY = "Image Tools" @classmethod def INPUT_TYPES(s): From 38ca94599f7444f55589308d1cf611fb77f6ca16 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 26 Feb 2026 11:07:35 +0000 Subject: [PATCH 08/17] pyopengl-accelerate can cause object to be numpy ints instead of bare ints, which the glDeleteTextures function does not accept, explicitly cast to int (#12650) --- comfy_extras/nodes_glsl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py index 75ffb6d80..6d210b307 100644 --- a/comfy_extras/nodes_glsl.py +++ b/comfy_extras/nodes_glsl.py @@ -717,11 +717,11 @@ def _render_shader_batch( gl.glUseProgram(0) for tex in input_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) for tex in output_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) for tex in ping_pong_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) if fbo is not None: gl.glDeleteFramebuffers(1, [fbo]) for pp_fbo in ping_pong_fbos: From 420e900f692f72a4e0108594a80a3465c036bebe Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 26 Feb 2026 12:19:38 -0800 Subject: [PATCH 09/17] main: load aimdo earlier (#12655) Some custom node packs are naughty, and violate the dont-load-torch-on-load rule. This causes aimdo to lose preference on its allocator hook on linux. Go super early on the aimdo first-stage init before custom nodes are mentioned at all. --- main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 39e605deb..3fe8f0589 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,10 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. @@ -173,10 +177,6 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") -import comfy_aimdo.control - -if enables_dynamic_vram(): - comfy_aimdo.control.init() import comfy.utils From fd41ec97cc2e457f322afdf136fd2f2b2454a240 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 26 Feb 2026 22:52:10 +0200 Subject: [PATCH 10/17] feat(api-nodes): add NanoBanana2 (#12660) --- comfy_api_nodes/apis/gemini.py | 6 + comfy_api_nodes/nodes_bytedance.py | 2 +- comfy_api_nodes/nodes_gemini.py | 203 +++++++++++++++++++++++++++-- 3 files changed, 196 insertions(+), 15 deletions(-) diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 3304d7e76..639035fef 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -127,9 +127,15 @@ class GeminiImageConfig(BaseModel): imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions) +class GeminiThinkingConfig(BaseModel): + includeThoughts: bool | None = Field(None) + thinkingLevel: str = Field(...) + + class GeminiImageGenerationConfig(GeminiGenerationConfig): responseModalities: list[str] | None = Field(None) imageConfig: GeminiImageConfig | None = Field(None) + thinkingConfig: GeminiThinkingConfig | None = Field(None) class GeminiImageGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index cfc604aa3..6dbd5984e 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -186,7 +186,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ByteDanceSeedreamNode", - display_name="ByteDance Seedream 5.0", + display_name="ByteDance Seedream 4.5 & 5.0", category="api node/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index b69285be5..3fe804e0b 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -29,6 +29,7 @@ from comfy_api_nodes.apis.gemini import ( GeminiRole, GeminiSystemInstructionContent, GeminiTextPart, + GeminiThinkingConfig, Modality, ) from comfy_api_nodes.util import ( @@ -55,6 +56,21 @@ GEMINI_IMAGE_SYS_PROMPT = ( "Prioritize generating the visual representation above any text, formatting, or conversational requests." ) +GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution"]), + expr=""" + ( + $m := widgets.model; + $r := widgets.resolution; + $isFlash := $contains($m, "nano banana 2"); + $flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123}; + $proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24}; + $prices := $isFlash ? $flashPrices : $proPrices; + {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} + ) + """, +) + class GeminiModel(str, Enum): """ @@ -229,6 +245,10 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 120.0 + elif response.modelVersion == "gemini-3.1-flash-image-preview": + input_tokens_price = 0.5 + output_text_tokens_price = 3.0 + output_image_tokens_price = 60.0 else: return None final_price = response.usageMetadata.promptTokenCount * input_tokens_price @@ -686,7 +706,7 @@ class GeminiImage2(IO.ComfyNode): ), IO.Combo.Input( "model", - options=["gemini-3-pro-image-preview"], + options=["gemini-3-pro-image-preview", "Nano Banana 2 (Gemini 3.1 Flash Image)"], ), IO.Int.Input( "seed", @@ -750,19 +770,7 @@ class GeminiImage2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), - expr=""" - ( - $r := widgets.resolution; - ($contains($r,"1k") or $contains($r,"2k")) - ? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}} - : $contains($r,"4k") - ? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}} - : {"type":"text","text":"Token-based"} - ) - """, - ), + price_badge=GEMINI_IMAGE_2_PRICE_BADGE, ) @classmethod @@ -779,6 +787,10 @@ class GeminiImage2(IO.ComfyNode): system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model = "gemini-3.1-flash-image-preview" + if response_modalities == "IMAGE+TEXT": + raise ValueError("IMAGE+TEXT is not currently available for the Nano Banana 2 model.") parts: list[GeminiPart] = [GeminiPart(text=prompt)] if images is not None: @@ -815,6 +827,168 @@ class GeminiImage2(IO.ComfyNode): return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) +class GeminiNanoBanana2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiNanoBanana2", + display_name="Nano Banana 2", + category="api node/image/Gemini", + description="Generate or edit images synchronously via Google Vertex API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt describing the image to generate or the edits to apply. " + "Include any constraints, styles, or details the model should follow.", + default="", + ), + IO.Combo.Input( + "model", + options=["Nano Banana 2 (Gemini 3.1 Flash Image)"], + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "4:5", + "5:4", + "9:16", + "16:9", + "21:9", + # "1:4", + # "4:1", + # "8:1", + # "1:8", + ], + default="auto", + tooltip="If set to 'auto', matches your input image's aspect ratio; " + "if no image is provided, a 16:9 square is usually generated.", + ), + IO.Combo.Input( + "resolution", + options=[ + # "512px", + "1K", + "2K", + "4K", + ], + tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.", + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE"], + advanced=True, + ), + IO.Combo.Input( + "thinking_level", + options=["MINIMAL", "HIGH"], + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional reference image(s). " + "To include multiple images, use the Batch Images node (up to 14).", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + advanced=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=GEMINI_IMAGE_2_PRICE_BADGE, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + aspect_ratio: str, + resolution: str, + response_modalities: str, + thinking_level: str, + images: Input.Image | None = None, + files: list[GeminiPart] | None = None, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model = "gemini-3.1-flash-image-preview" + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + if images is not None: + if get_number_of_images(images) > 14: + raise ValueError("The current maximum number of supported images is 14.") + parts.extend(await create_image_parts(cls, images)) + if files is not None: + parts.extend(files) + + image_config = GeminiImageConfig(imageSize=resolution) + if aspect_ratio != "auto": + image_config.aspectRatio = aspect_ratio + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=image_config, + thinkingConfig=GeminiThinkingConfig(thinkingLevel=thinking_level), + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) + + class GeminiExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -822,6 +996,7 @@ class GeminiExtension(ComfyExtension): GeminiNode, GeminiImage, GeminiImage2, + GeminiNanoBanana2, GeminiInputFiles, ] From 88d05fe483a9f420a69d681e615422930404292b Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 27 Feb 2026 04:52:45 +0800 Subject: [PATCH 11/17] chore: update workflow templates to v0.9.4 (#12664) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 88a056e5c..b5b292980 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.3 +comfyui-workflow-templates==0.9.4 comfyui-embedded-docs==0.4.3 torch torchsde From 3dd10a59c00248d00f0cb0ab794ff1bb9fb00a5f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 26 Feb 2026 15:59:22 -0500 Subject: [PATCH 12/17] ComfyUI v0.15.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 6dbda1a87..6a35c6de3 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.15.0" +__version__ = "0.15.1" diff --git a/pyproject.toml b/pyproject.toml index eaf558740..1b2318273 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.15.0" +version = "0.15.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 3811780e4f73f9dbace01a85e6d97502406f8ccb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 26 Feb 2026 14:12:29 -0800 Subject: [PATCH 13/17] Portable with cu128 isn't useful anymore. (#12666) Users should either use the cu126 one or the regular one (cu130 at the moment) The cu128 portable is still included in the latest github release but I will stop including it as soon as it becomes slightly annoying to deal with. This might happen as soon as next week. --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index d97c6181d..56b7966cf 100644 --- a/README.md +++ b/README.md @@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) -[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z). - [Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). #### How do I share models between another UI and ComfyUI? From b233dbe0bc179847b81680e0b59c493a8dc8d9a6 Mon Sep 17 00:00:00 2001 From: fappaz Date: Fri, 27 Feb 2026 12:19:19 +1300 Subject: [PATCH 14/17] feat(ace-step): add ACE-Step 1.5 lycoris key alias mapping for LoKR #12638 (#12665) --- comfy/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/lora.py b/comfy/lora.py index 279cf38bb..f36ddb046 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -337,6 +337,7 @@ def model_lora_keys_unet(model, key_map={}): if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"): key_lora = k[len("diffusion_model.decoder."):-len(".weight")] key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format return key_map From 08b26ed7c2fe43417058f4c6c5934de3cebf3f20 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 26 Feb 2026 15:59:24 -0800 Subject: [PATCH 15/17] bug_report template: Push harder for logs (#12657) We get a lot od bug reports without logs, especially for performance issues. --- .github/ISSUE_TEMPLATE/bug-report.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 6556677e0..95cc48f88 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -16,7 +16,7 @@ body: ## Very Important - Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored. + Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored. - type: checkboxes id: custom-nodes-test attributes: From c7f7d52b684f661d911f1747bb6954978fa1d1b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 27 Feb 2026 02:59:05 +0200 Subject: [PATCH 16/17] feat: Support SDPose-OOD (#12661) --- .../modules/diffusionmodules/openaimodel.py | 6 + comfy/ldm/modules/sdpose.py | 130 +++ comfy/model_detection.py | 6 +- comfy/supported_models.py | 3 +- comfy_api/latest/_io.py | 5 +- comfy_extras/nodes_sdpose.py | 740 ++++++++++++++++++ nodes.py | 1 + 7 files changed, 888 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/modules/sdpose.py create mode 100644 comfy_extras/nodes_sdpose.py diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4c8d53cac..295310df6 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -18,6 +18,8 @@ import comfy.patcher_extension import comfy.ops ops = comfy.ops.disable_weight_init +from ..sdpose import HeatmapHead + class TimestepBlock(nn.Module): """ Any module where forward() takes timestep embeddings as a second argument. @@ -441,6 +443,7 @@ class UNetModel(nn.Module): disable_temporal_crossattention=False, max_ddpm_temb_period=10000, attn_precision=None, + heatmap_head=False, device=None, operations=ops, ): @@ -827,6 +830,9 @@ class UNetModel(nn.Module): #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) + if heatmap_head: + self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations) + def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, diff --git a/comfy/ldm/modules/sdpose.py b/comfy/ldm/modules/sdpose.py new file mode 100644 index 000000000..d67b60b76 --- /dev/null +++ b/comfy/ldm/modules/sdpose.py @@ -0,0 +1,130 @@ +import torch +import numpy as np +from scipy.ndimage import gaussian_filter + +class HeatmapHead(torch.nn.Module): + def __init__( + self, + in_channels=640, + out_channels=133, + input_size=(768, 1024), + heatmap_scale=4, + deconv_out_channels=(640,), + deconv_kernel_sizes=(4,), + conv_out_channels=(640,), + conv_kernel_sizes=(1,), + final_layer_kernel_size=1, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale) + self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32) + + # Deconv layers + if deconv_out_channels: + deconv_layers = [] + for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes): + if kernel_size == 4: + padding, output_padding = 1, 0 + elif kernel_size == 3: + padding, output_padding = 1, 1 + elif kernel_size == 2: + padding, output_padding = 0, 0 + else: + raise ValueError(f'Unsupported kernel size {kernel_size}') + + deconv_layers.extend([ + operations.ConvTranspose2d(in_channels, out_ch, kernel_size, + stride=2, padding=padding, output_padding=output_padding, bias=False, device=device, dtype=dtype), + torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype), + torch.nn.SiLU(inplace=True) + ]) + in_channels = out_ch + self.deconv_layers = torch.nn.Sequential(*deconv_layers) + else: + self.deconv_layers = torch.nn.Identity() + + # Conv layers + if conv_out_channels: + conv_layers = [] + for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes): + padding = (kernel_size - 1) // 2 + conv_layers.extend([ + operations.Conv2d(in_channels, out_ch, kernel_size, + stride=1, padding=padding, device=device, dtype=dtype), + torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype), + torch.nn.SiLU(inplace=True) + ]) + in_channels = out_ch + self.conv_layers = torch.nn.Sequential(*conv_layers) + else: + self.conv_layers = torch.nn.Identity() + + self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size, padding=final_layer_kernel_size // 2, device=device, dtype=dtype) + + def forward(self, x): # Decode heatmaps to keypoints + heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x))) + heatmaps_np = heatmaps.float().cpu().numpy() # (B, K, H, W) + B, K, H, W = heatmaps_np.shape + + batch_keypoints = [] + batch_scores = [] + + for b in range(B): + hm = heatmaps_np[b].copy() # (K, H, W) + + # --- vectorised argmax --- + flat = hm.reshape(K, -1) + idx = np.argmax(flat, axis=1) + scores = flat[np.arange(K), idx].copy() + y_locs, x_locs = np.unravel_index(idx, (H, W)) + keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32) # (K, 2) in heatmap space + invalid = scores <= 0. + keypoints[invalid] = -1 + + # --- DARK sub-pixel refinement (UDP) --- + # 1. Gaussian blur with max-preserving normalisation + border = 5 # (kernel-1)//2 for kernel=11 + for k in range(K): + origin_max = np.max(hm[k]) + dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32) + dr[border:-border, border:-border] = hm[k].copy() + dr = gaussian_filter(dr, sigma=2.0) + hm[k] = dr[border:-border, border:-border].copy() + cur_max = np.max(hm[k]) + if cur_max > 0: + hm[k] *= origin_max / cur_max + # 2. Log-space for Taylor expansion + np.clip(hm, 1e-3, 50., hm) + np.log(hm, hm) + # 3. Hessian-based Newton step + hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten() + index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * (W + 2) + index += (W + 2) * (H + 2) * np.arange(0, K) + index = index.astype(int).reshape(-1, 1) + i_ = hm_pad[index] + ix1 = hm_pad[index + 1] + iy1 = hm_pad[index + W + 2] + ix1y1 = hm_pad[index + W + 3] + ix1_y1_ = hm_pad[index - W - 3] + ix1_ = hm_pad[index - 1] + iy1_ = hm_pad[index - 2 - W] + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1) + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1) + + # --- restore to input image space --- + keypoints = keypoints * self.scale_factor + keypoints[invalid] = -1 + + batch_keypoints.append(keypoints) + batch_scores.append(scores) + + return batch_keypoints, batch_scores diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 030ae6980..b4b51b200 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -795,6 +795,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["use_temporal_resblock"] = False unet_config["use_temporal_attention"] = False + heatmap_key = '{}heatmap_head.conv_layers.0.weight'.format(key_prefix) + if heatmap_key in state_dict_keys: + unet_config["heatmap_head"] = True + return unet_config def model_config_from_unet_config(unet_config, state_dict=None): @@ -1015,7 +1019,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], - 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8, + 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 'use_temporal_attention': False, 'use_temporal_resblock': False} diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6d08ff0a5..1bb7b7011 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -525,7 +525,8 @@ class LotusD(SD20): } unet_extra_config = { - "num_classes": 'sequential' + "num_classes": 'sequential', + "num_head_channels": 64, } def get_model(self, state_dict, prefix="", device=None): diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 025727071..189d7d9bc 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1224,9 +1224,10 @@ class BoundingBox(ComfyTypeIO): class Input(WidgetInput): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, - socketless: bool=True, default: dict=None, component: str=None): + socketless: bool=True, default: dict=None, component: str=None, force_input: bool=None): super().__init__(id, display_name, optional, tooltip, None, default, socketless) self.component = component + self.force_input = force_input if default is None: self.default = {"x": 0, "y": 0, "width": 512, "height": 512} @@ -1234,6 +1235,8 @@ class BoundingBox(ComfyTypeIO): d = super().as_dict() if self.component: d["component"] = self.component + if self.force_input is not None: + d["forceInput"] = self.force_input return d diff --git a/comfy_extras/nodes_sdpose.py b/comfy_extras/nodes_sdpose.py new file mode 100644 index 000000000..71441848e --- /dev/null +++ b/comfy_extras/nodes_sdpose.py @@ -0,0 +1,740 @@ +import torch +import comfy.utils +import numpy as np +import math +import colorsys +from tqdm import tqdm +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_extras.nodes_lotus import LotusConditioning + + +def _preprocess_keypoints(kp_raw, sc_raw): + """Insert neck keypoint and remap from MMPose to OpenPose ordering. + + Returns (kp, sc) where kp has shape (134, 2) and sc has shape (134,). + Layout: + 0-17 body (18 kp, OpenPose order) + 18-23 feet (6 kp) + 24-91 face (68 kp) + 92-112 right hand (21 kp) + 113-133 left hand (21 kp) + """ + kp = np.array(kp_raw, dtype=np.float32) + sc = np.array(sc_raw, dtype=np.float32) + if len(kp) >= 17: + neck = (kp[5] + kp[6]) / 2 + neck_score = min(sc[5], sc[6]) if sc[5] > 0.3 and sc[6] > 0.3 else 0 + kp = np.insert(kp, 17, neck, axis=0) + sc = np.insert(sc, 17, neck_score) + mmpose_idx = np.array([17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]) + openpose_idx = np.array([ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]) + tmp_kp, tmp_sc = kp.copy(), sc.copy() + tmp_kp[openpose_idx] = kp[mmpose_idx] + tmp_sc[openpose_idx] = sc[mmpose_idx] + kp, sc = tmp_kp, tmp_sc + return kp, sc + + +def _to_openpose_frames(all_keypoints, all_scores, height, width): + """Convert raw keypoint lists to a list of OpenPose-style frame dicts. + + Each frame dict contains: + canvas_width, canvas_height, people: list of person dicts with keys: + pose_keypoints_2d - 18 body kp as flat [x,y,score,...] (absolute pixels) + foot_keypoints_2d - 6 foot kp as flat [x,y,score,...] (absolute pixels) + face_keypoints_2d - 70 face kp as flat [x,y,score,...] (absolute pixels) + indices 0-67: 68 face landmarks + index 68: right eye (body[14]) + index 69: left eye (body[15]) + hand_right_keypoints_2d - 21 right-hand kp (absolute pixels) + hand_left_keypoints_2d - 21 left-hand kp (absolute pixels) + """ + def _flatten(kp_slice, sc_slice): + return np.stack([kp_slice[:, 0], kp_slice[:, 1], sc_slice], axis=1).flatten().tolist() + + frames = [] + for img_idx in range(len(all_keypoints)): + people = [] + for kp_raw, sc_raw in zip(all_keypoints[img_idx], all_scores[img_idx]): + kp, sc = _preprocess_keypoints(kp_raw, sc_raw) + # 70 face kp = 68 face landmarks + REye (body[14]) + LEye (body[15]) + face_kp = np.concatenate([kp[24:92], kp[[14, 15]]], axis=0) + face_sc = np.concatenate([sc[24:92], sc[[14, 15]]], axis=0) + people.append({ + "pose_keypoints_2d": _flatten(kp[0:18], sc[0:18]), + "foot_keypoints_2d": _flatten(kp[18:24], sc[18:24]), + "face_keypoints_2d": _flatten(face_kp, face_sc), + "hand_right_keypoints_2d": _flatten(kp[92:113], sc[92:113]), + "hand_left_keypoints_2d": _flatten(kp[113:134], sc[113:134]), + }) + frames.append({"canvas_width": width, "canvas_height": height, "people": people}) + return frames + + +class KeypointDraw: + """ + Pose keypoint drawing class that supports both numpy and cv2 backends. + """ + def __init__(self): + try: + import cv2 + self.draw = cv2 + except ImportError: + self.draw = self + + # Hand connections (same for both hands) + self.hand_edges = [ + [0, 1], [1, 2], [2, 3], [3, 4], # thumb + [0, 5], [5, 6], [6, 7], [7, 8], # index + [0, 9], [9, 10], [10, 11], [11, 12], # middle + [0, 13], [13, 14], [14, 15], [15, 16], # ring + [0, 17], [17, 18], [18, 19], [19, 20], # pinky + ] + + # Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed) + self.body_limbSeq = [ + [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], + [1, 16], [16, 18] + ] + + # Colors matching DWPose + self.colors = [ + [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], + [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], + [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85] + ] + + @staticmethod + def circle(canvas_np, center, radius, color, **kwargs): + """Draw a filled circle using NumPy vectorized operations.""" + cx, cy = center + h, w = canvas_np.shape[:2] + + radius_int = int(np.ceil(radius)) + + y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1) + x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1) + + if y_max <= y_min or x_max <= x_min: + return + + y, x = np.ogrid[y_min:y_max, x_min:x_max] + mask = (x - cx)**2 + (y - cy)**2 <= radius**2 + canvas_np[y_min:y_max, x_min:x_max][mask] = color + + @staticmethod + def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs): + """Draw line using Bresenham's algorithm with NumPy operations.""" + x0, y0, x1, y1 = *pt1, *pt2 + h, w = canvas_np.shape[:2] + dx, dy = abs(x1 - x0), abs(y1 - y0) + sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1) + err, x, y, line_points = dx - dy, x0, y0, [] + + while True: + line_points.append((x, y)) + if x == x1 and y == y1: + break + e2 = 2 * err + if e2 > -dy: + err, x = err - dy, x + sx + if e2 < dx: + err, y = err + dx, y + sy + + if thickness > 1: + radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5)) + for px, py in line_points: + y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1) + if y_max > y_min and x_max > x_min: + yy, xx = np.ogrid[y_min:y_max, x_min:x_max] + canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color + else: + line_points = np.array(line_points) + valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w) + if (valid_points := line_points[valid]).size: + canvas_np[valid_points[:, 1], valid_points[:, 0]] = color + + @staticmethod + def fillConvexPoly(canvas_np, pts, color, **kwargs): + """Fill polygon using vectorized scanline algorithm.""" + if len(pts) < 3: + return + pts = np.array(pts, dtype=np.int32) + h, w = canvas_np.shape[:2] + y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1) + if y_max <= y_min or x_max <= x_min: + return + yy, xx = np.mgrid[y_min:y_max, x_min:x_max] + mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool) + + for i in range(len(pts)): + p1, p2 = pts[i], pts[(i + 1) % len(pts)] + y1, y2 = p1[1], p2[1] + if y1 == y2: + continue + if y1 > y2: + p1, p2, y1, y2 = p2, p1, p2[1], p1[1] + if not (edge_mask := (yy >= y1) & (yy < y2)).any(): + continue + mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1)) + + canvas_np[y_min:y_max, x_min:x_max][mask] = color + + @staticmethod + def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs): + """Python implementation of cv2.ellipse2Poly.""" + axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output + angle = angle % 360 + if arc_start > arc_end: + arc_start, arc_end = arc_end, arc_start + while arc_start < 0: + arc_start, arc_end = arc_start + 360, arc_end + 360 + while arc_end > 360: + arc_end, arc_start = arc_end - 360, arc_start - 360 + if arc_end - arc_start > 360: + arc_start, arc_end = 0, 360 + + angle_rad = math.radians(angle) + alpha, beta = math.cos(angle_rad), math.sin(angle_rad) + pts = [] + for i in range(arc_start, arc_end + delta, delta): + theta_rad = math.radians(min(i, arc_end)) + x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad) + pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))]) + + unique_pts, prev_pt = [], (float('inf'), float('inf')) + for pt in pts: + if (pt_tuple := tuple(pt)) != prev_pt: + unique_pts.append(pt) + prev_pt = pt_tuple + + return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]] + + def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3, + draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3): + """ + Draw wholebody keypoints (134 keypoints after processing) in DWPose style. + + Expected keypoint format (after neck insertion and remapping): + - Body: 0-17 (18 keypoints in OpenPose format, neck at index 1) + - Foot: 18-23 (6 keypoints) + - Face: 24-91 (68 landmarks) + - Right hand: 92-112 (21 keypoints) + - Left hand: 113-133 (21 keypoints) + + Args: + canvas: The canvas to draw on (numpy array) + keypoints: Array of keypoint coordinates + scores: Optional confidence scores for each keypoint + threshold: Minimum confidence threshold for drawing keypoints + + Returns: + canvas: The canvas with keypoints drawn + """ + H, W, C = canvas.shape + + # Draw body limbs + if draw_body and len(keypoints) >= 18: + for i, limb in enumerate(self.body_limbSeq): + # Convert from 1-indexed to 0-indexed + idx1, idx2 = limb[0] - 1, limb[1] - 1 + + if idx1 >= 18 or idx2 >= 18: + continue + + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + Y = [keypoints[idx1][0], keypoints[idx2][0]] + X = [keypoints[idx1][1], keypoints[idx2][1]] + mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2 + length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) + + if length < 1: + continue + + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + + polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1) + + self.draw.fillConvexPoly(canvas, polygon, self.colors[i % len(self.colors)]) + + # Draw body keypoints + if draw_body and len(keypoints) >= 18: + for i in range(18): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1) + + # Draw foot keypoints (18-23, 6 keypoints) + if draw_feet and len(keypoints) >= 24: + for i in range(18, 24): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1) + + # Draw right hand (92-112) + if draw_hands and len(keypoints) >= 113: + eps = 0.01 + for ie, edge in enumerate(self.hand_edges): + idx1, idx2 = 92 + edge[0], 92 + edge[1] + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) + x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) + + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: + # HSV to RGB conversion for rainbow colors + r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) + color = (int(r * 255), int(g * 255), int(b * 255)) + self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2) + + # Draw right hand keypoints + for i in range(92, 113): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + + # Draw left hand (113-133) + if draw_hands and len(keypoints) >= 134: + eps = 0.01 + for ie, edge in enumerate(self.hand_edges): + idx1, idx2 = 113 + edge[0], 113 + edge[1] + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) + x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) + + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: + # HSV to RGB conversion for rainbow colors + r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) + color = (int(r * 255), int(g * 255), int(b * 255)) + self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2) + + # Draw left hand keypoints + for i in range(113, 134): + if scores is not None and i < len(scores) and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + + # Draw face keypoints (24-91) - white dots only, no lines + if draw_face and len(keypoints) >= 92: + eps = 0.01 + for i in range(24, 92): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1) + + return canvas + +class SDPoseDrawKeypoints(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseDrawKeypoints", + category="image/preprocessors", + search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "pose"], + inputs=[ + io.Custom("POSE_KEYPOINT").Input("keypoints"), + io.Boolean.Input("draw_body", default=True), + io.Boolean.Input("draw_hands", default=True), + io.Boolean.Input("draw_face", default=True), + io.Boolean.Input("draw_feet", default=False), + io.Int.Input("stick_width", default=4, min=1, max=10, step=1), + io.Int.Input("face_point_size", default=3, min=1, max=10, step=1), + io.Float.Input("score_threshold", default=0.3, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, keypoints, draw_body, draw_hands, draw_face, draw_feet, stick_width, face_point_size, score_threshold) -> io.NodeOutput: + if not keypoints: + return io.NodeOutput(torch.zeros((1, 64, 64, 3), dtype=torch.float32)) + height = keypoints[0]["canvas_height"] + width = keypoints[0]["canvas_width"] + + def _parse(flat, n): + arr = np.array(flat, dtype=np.float32).reshape(n, 3) + return arr[:, :2], arr[:, 2] + + def _zeros(n): + return np.zeros((n, 2), dtype=np.float32), np.zeros(n, dtype=np.float32) + + pose_outputs = [] + drawer = KeypointDraw() + + for frame in tqdm(keypoints, desc="Drawing keypoints on frames"): + canvas = np.zeros((height, width, 3), dtype=np.uint8) + for person in frame["people"]: + body_kp, body_sc = _parse(person["pose_keypoints_2d"], 18) + foot_raw = person.get("foot_keypoints_2d") + foot_kp, foot_sc = _parse(foot_raw, 6) if foot_raw else _zeros(6) + face_kp, face_sc = _parse(person["face_keypoints_2d"], 70) + face_kp, face_sc = face_kp[:68], face_sc[:68] # drop appended eye kp; body already draws them + rhand_kp, rhand_sc = _parse(person["hand_right_keypoints_2d"], 21) + lhand_kp, lhand_sc = _parse(person["hand_left_keypoints_2d"], 21) + + kp = np.concatenate([body_kp, foot_kp, face_kp, rhand_kp, lhand_kp], axis=0) + sc = np.concatenate([body_sc, foot_sc, face_sc, rhand_sc, lhand_sc], axis=0) + + canvas = drawer.draw_wholebody_keypoints( + canvas, kp, sc, + threshold=score_threshold, + draw_body=draw_body, draw_feet=draw_feet, + draw_face=draw_face, draw_hands=draw_hands, + stick_width=stick_width, face_point_size=face_point_size, + ) + pose_outputs.append(canvas) + + pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0) + final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0 + return io.NodeOutput(final_pose_output) + +class SDPoseKeypointExtractor(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseKeypointExtractor", + category="image/preprocessors", + search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "sdpose"], + description="Extract pose keypoints from images using the SDPose model: https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints", + inputs=[ + io.Model.Input("model"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Int.Input("batch_size", default=16, min=1, max=10000, step=1), + io.BoundingBox.Input("bboxes", optional=True, force_input=True, tooltip="Optional bounding boxes for more accurate detections. Required for multi-person detection."), + ], + outputs=[ + io.Custom("POSE_KEYPOINT").Output("keypoints", tooltip="Keypoints in OpenPose frame format (canvas_width, canvas_height, people)"), + ], + ) + + @classmethod + def execute(cls, model, vae, image, batch_size, bboxes=None) -> io.NodeOutput: + + height, width = image.shape[-3], image.shape[-2] + context = LotusConditioning().execute().result[0] + + # Use output_block_patch to capture the last 640-channel feature + def output_patch(h, hsp, transformer_options): + nonlocal captured_feat + if h.shape[1] == 640: # Capture the features for wholebody + captured_feat = h.clone() + return h, hsp + + model_clone = model.clone() + model_clone.model_options["transformer_options"] = {"patches": {"output_block_patch": [output_patch]}} + + if not hasattr(model.model.diffusion_model, 'heatmap_head'): + raise ValueError("The provided model does not have a heatmap_head. Please use SDPose model from here https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints.") + + head = model.model.diffusion_model.heatmap_head + total_images = image.shape[0] + captured_feat = None + + model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768 + model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024 + + def _run_on_latent(latent_batch): + """Run one forward pass and return (keypoints_list, scores_list) for the batch.""" + nonlocal captured_feat + captured_feat = None + _ = comfy.sample.sample( + model_clone, + noise=torch.zeros_like(latent_batch), + steps=1, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=context, negative=context, + latent_image=latent_batch, disable_noise=True, disable_pbar=True, + ) + return head(captured_feat) # keypoints_batch, scores_batch + + # all_keypoints / all_scores are lists-of-lists: + # outer index = input image index + # inner index = detected person (one per bbox, or one for full-image) + all_keypoints = [] # shape: [n_images][n_persons] + all_scores = [] # shape: [n_images][n_persons] + pbar = comfy.utils.ProgressBar(total_images) + + if bboxes is not None: + if not isinstance(bboxes, list): + bboxes = [[bboxes]] + elif len(bboxes) == 0: + bboxes = [None] * total_images + # --- bbox-crop mode: one forward pass per crop ------------------------- + for img_idx in tqdm(range(total_images), desc="Extracting keypoints from crops"): + img = image[img_idx:img_idx + 1] # (1, H, W, C) + # Broadcasting: if fewer bbox lists than images, repeat the last one. + img_bboxes = bboxes[min(img_idx, len(bboxes) - 1)] if bboxes else None + + img_keypoints = [] + img_scores = [] + + if img_bboxes: + for bbox in img_bboxes: + x1 = max(0, int(bbox["x"])) + y1 = max(0, int(bbox["y"])) + x2 = min(width, int(bbox["x"] + bbox["width"])) + y2 = min(height, int(bbox["y"] + bbox["height"])) + + if x2 <= x1 or y2 <= y1: + continue + + crop_h_px, crop_w_px = y2 - y1, x2 - x1 + crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C) + + # scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size. + scale = min(model_h / crop_h_px, model_w / crop_w_px) + scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale)) + pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2 + + crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW + scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled") + padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device) + padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled + crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC + + latent_crop = vae.encode(crop_resized) + kp_batch, sc_batch = _run_on_latent(latent_crop) + kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space + + # remove padding offset, undo scale, offset to full-image coordinates. + kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32) + kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1 + kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1 + + img_keypoints.append(kp) + img_scores.append(sc) + else: + # No bboxes for this image – run on the full image + latent_img = vae.encode(img) + kp_batch, sc_batch = _run_on_latent(latent_img) + img_keypoints.append(kp_batch[0]) + img_scores.append(sc_batch[0]) + + all_keypoints.append(img_keypoints) + all_scores.append(img_scores) + pbar.update(1) + + else: # full-image mode, batched + tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints") + for batch_start in range(0, total_images, batch_size): + batch_end = min(batch_start + batch_size, total_images) + latent_batch = vae.encode(image[batch_start:batch_end]) + + kp_batch, sc_batch = _run_on_latent(latent_batch) + + for kp, sc in zip(kp_batch, sc_batch): + all_keypoints.append([kp]) + all_scores.append([sc]) + tqdm_pbar.update(1) + + pbar.update(batch_end - batch_start) + + openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width) + return io.NodeOutput(openpose_frames) + + +def get_face_bboxes(kp2ds, scale, image_shape): + h, w = image_shape + kp2ds_face = kp2ds.copy()[1:] * (w, h) + + min_x, min_y = np.min(kp2ds_face, axis=0) + max_x, max_y = np.max(kp2ds_face, axis=0) + + initial_width = max_x - min_x + initial_height = max_y - min_y + + if initial_width <= 0 or initial_height <= 0: + return [0, 0, 0, 0] + + initial_area = initial_width * initial_height + + expanded_area = initial_area * scale + + new_width = np.sqrt(expanded_area * (initial_width / initial_height)) + new_height = np.sqrt(expanded_area * (initial_height / initial_width)) + + delta_width = (new_width - initial_width) / 2 + delta_height = (new_height - initial_height) / 4 + + expanded_min_x = max(min_x - delta_width, 0) + expanded_max_x = min(max_x + delta_width, w) + expanded_min_y = max(min_y - 3 * delta_height, 0) + expanded_max_y = min(max_y + delta_height, h) + + return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)] + +class SDPoseFaceBBoxes(io.ComfyNode): + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseFaceBBoxes", + category="image/preprocessors", + search_aliases=["face bbox", "face bounding box", "pose", "keypoints"], + inputs=[ + io.Custom("POSE_KEYPOINT").Input("keypoints"), + io.Float.Input("scale", default=1.5, min=1.0, max=10.0, step=0.1, tooltip="Multiplier for the bounding box area around each detected face."), + io.Boolean.Input("force_square", default=True, tooltip="Expand the shorter bbox axis so the crop region is always square."), + ], + outputs=[ + io.BoundingBox.Output("bboxes", tooltip="Face bounding boxes per frame, compatible with SDPoseKeypointExtractor bboxes input."), + ], + ) + + @classmethod + def execute(cls, keypoints, scale, force_square) -> io.NodeOutput: + all_bboxes = [] + for frame in keypoints: + h = frame["canvas_height"] + w = frame["canvas_width"] + frame_bboxes = [] + for person in frame["people"]: + face_flat = person.get("face_keypoints_2d", []) + if not face_flat: + continue + # Parse absolute-pixel face keypoints (70 kp: 68 landmarks + REye + LEye) + face_arr = np.array(face_flat, dtype=np.float32).reshape(-1, 3) + face_xy = face_arr[:, :2] # (70, 2) in absolute pixels + + kp_norm = face_xy / np.array([w, h], dtype=np.float32) + kp_padded = np.vstack([np.zeros((1, 2), dtype=np.float32), kp_norm]) # (71, 2) + + x1, x2, y1, y2 = get_face_bboxes(kp_padded, scale, (h, w)) + if x2 > x1 and y2 > y1: + if force_square: + bw, bh = x2 - x1, y2 - y1 + if bw != bh: + side = max(bw, bh) + cx, cy = (x1 + x2) // 2, (y1 + y2) // 2 + half = side // 2 + x1 = max(0, cx - half) + y1 = max(0, cy - half) + x2 = min(w, x1 + side) + y2 = min(h, y1 + side) + # Re-anchor if clamped + x1 = max(0, x2 - side) + y1 = max(0, y2 - side) + frame_bboxes.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}) + + all_bboxes.append(frame_bboxes) + + return io.NodeOutput(all_bboxes) + + +class CropByBBoxes(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CropByBBoxes", + category="image/preprocessors", + search_aliases=["crop", "face crop", "bbox crop", "pose", "bounding box"], + description="Crop and resize regions from the input image batch based on provided bounding boxes.", + inputs=[ + io.Image.Input("image"), + io.BoundingBox.Input("bboxes", force_input=True), + io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."), + io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."), + io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."), + ], + outputs=[ + io.Image.Output(tooltip="All crops stacked into a single image batch."), + ], + ) + + @classmethod + def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput: + total_frames = image.shape[0] + img_h = image.shape[1] + img_w = image.shape[2] + num_ch = image.shape[3] + + if not isinstance(bboxes, list): + bboxes = [[bboxes]] + elif len(bboxes) == 0: + return io.NodeOutput(image) + + crops = [] + + for frame_idx in range(total_frames): + frame_bboxes = bboxes[min(frame_idx, len(bboxes) - 1)] + if not frame_bboxes: + continue + + frame_chw = image[frame_idx].permute(2, 0, 1).unsqueeze(0) # BHWC → BCHW (1, C, H, W) + + # Union all bboxes for this frame into a single crop region + x1 = min(int(b["x"]) for b in frame_bboxes) + y1 = min(int(b["y"]) for b in frame_bboxes) + x2 = max(int(b["x"] + b["width"]) for b in frame_bboxes) + y2 = max(int(b["y"] + b["height"]) for b in frame_bboxes) + + if padding > 0: + x1 = max(0, x1 - padding) + y1 = max(0, y1 - padding) + x2 = min(img_w, x2 + padding) + y2 = min(img_h, y2 + padding) + + x1, x2 = max(0, x1), min(img_w, x2) + y1, y2 = max(0, y1), min(img_h, y2) + + # Fallback for empty/degenerate crops + if x2 <= x1 or y2 <= y1: + fallback_size = int(min(img_h, img_w) * 0.3) + fb_x1 = max(0, (img_w - fallback_size) // 2) + fb_y1 = max(0, int(img_h * 0.1)) + fb_x2 = min(img_w, fb_x1 + fallback_size) + fb_y2 = min(img_h, fb_y1 + fallback_size) + if fb_x2 <= fb_x1 or fb_y2 <= fb_y1: + crops.append(torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)) + continue + x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2 + + crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w) + resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled") + crops.append(resized) + + if not crops: + return io.NodeOutput(image) + + out_images = torch.cat(crops, dim=0).permute(0, 2, 3, 1) # (N, H, W, C) + return io.NodeOutput(out_images) + + +class SDPoseExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SDPoseKeypointExtractor, + SDPoseDrawKeypoints, + SDPoseFaceBBoxes, + CropByBBoxes, + ] + +async def comfy_entrypoint() -> SDPoseExtension: + return SDPoseExtension() diff --git a/nodes.py b/nodes.py index bff073e30..0222ec629 100644 --- a/nodes.py +++ b/nodes.py @@ -2447,6 +2447,7 @@ async def init_builtin_extra_nodes(): "nodes_toolkit.py", "nodes_replacements.py", "nodes_nag.py", + "nodes_sdpose.py", ] import_failed = [] From 35e9fce7756604050f07a05d090e697b81322c44 Mon Sep 17 00:00:00 2001 From: vickytsang Date: Thu, 26 Feb 2026 17:16:12 -0800 Subject: [PATCH 17/17] Enable Pytorch Attention for gfx950 (#12641) --- comfy/model_management.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 1fe56a62b..f73613f17 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -350,7 +350,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0] if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD @@ -378,7 +378,7 @@ try: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True if rocm_version >= (7, 0): if any((a in arch) for a in ["gfx1200", "gfx1201"]):