Merge branch 'master' into offload_to_mmap

Xiaoyu Xu, 2025-12-29 15:44:49 +08:00, committed by GitHub
commit 065b463b62
41 changed files with 1104 additions and 511 deletions


@@ -119,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly changes
 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
    - Releases a new stable version (e.g., v0.7.0) roughly every week.
+     - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
+       - Minor versions will be used for releases off the master branch.
+       - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
    - Commits outside of the stable release tags may be very unstable and break many custom nodes.
    - Serves as the foundation for the desktop release
@@ -209,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
+
+torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
 ### Instructions:
 Git clone this repo.


@@ -143,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC):
         # if multiple conds, split based on primary region
         if self.split_conds_to_windows and len(cond_in) > 1:
             region = window.get_region_index(len(cond_in))
-            logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
+            logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}")
             cond_in = [cond_in[region]]
         # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
         for actual_cond in cond_in:


@@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):

 def default_noise_sampler(x, seed=None):
     if seed is not None:
+        if x.device == torch.device("cpu"):
+            seed += 1
         generator = torch.Generator(device=x.device)
         generator.manual_seed(seed)
     else:
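For context, the seeded path of this sampler boils down to drawing noise from a per-device `torch.Generator` so results are reproducible without touching the global RNG. A minimal standalone sketch (not the ComfyUI implementation, and omitting the unseeded fallback) is below; the `seed += 1` on CPU in the diff appears intended to keep CPU-generated noise distinct from GPU noise for the same seed, but that intent is an inference, not stated in the commit.

```python
import torch

def seeded_noise_sampler(x, seed):
    # A generator bound to x's device gives reproducible noise without
    # touching the global RNG state.
    generator = torch.Generator(device=x.device)
    generator.manual_seed(seed)

    def sampler(sigma=None, sigma_next=None):
        return torch.randn(x.shape, dtype=x.dtype, layout=x.layout,
                           device=x.device, generator=generator)

    return sampler
```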


@@ -491,7 +491,8 @@ class NextDiT(nn.Module):
                 for layer_id in range(n_layers)
             ]
         )
-        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        # This norm final is in the lumina 2.0 code but isn't actually used for anything.
+        # self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)

         if self.pad_tokens_multiple is not None:
@@ -625,7 +626,7 @@ class NextDiT(nn.Module):
         if pooled is not None:
             pooled = self.clip_text_pooled_proj(pooled)
         else:
-            pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
+            pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)

         adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))


@@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):

 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
+    def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
         super().__init__()
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
         self.timestep_embedder = TimestepEmbedding(
@@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
             operations=operations
         )
+        self.use_additional_t_cond = use_additional_t_cond
+        if self.use_additional_t_cond:
+            self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)

-    def forward(self, timestep, hidden_states):
+    def forward(self, timestep, hidden_states, addition_t_cond=None):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
+        if self.use_additional_t_cond:
+            if addition_t_cond is None:
+                addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
+            timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
         return timesteps_emb
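The `addition_t_cond` pathway is just a learned two-entry embedding added onto the timestep embedding. A minimal sketch with plain `torch.nn` (ComfyUI's `operations.Embedding` with `out_dtype` is replaced here by a standard `nn.Embedding`, and the dimensions are made up for illustration):

```python
import torch
import torch.nn as nn

embedding_dim = 3072                              # hypothetical model width
addition_t_embedding = nn.Embedding(2, embedding_dim)

timesteps_emb = torch.randn(4, embedding_dim)     # (batch, dim) timestep embedding
addition_t_cond = torch.tensor([0, 1, 1, 0])      # per-sample binary flag; defaults to zeros when absent
timesteps_emb = timesteps_emb + addition_t_embedding(addition_t_cond)
```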
@@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
         num_attention_heads: int = 24,
         joint_attention_dim: int = 3584,
         pooled_projection_dim: int = 768,
-        guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
         default_ref_method="index",
         image_model=None,
         final_layer=True,
+        use_additional_t_cond=False,
         dtype=None,
         device=None,
         operations=None,
@@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module):
         self.time_text_embed = QwenTimestepProjEmbeddings(
             embedding_dim=self.inner_dim,
             pooled_projection_dim=pooled_projection_dim,
+            use_additional_t_cond=use_additional_t_cond,
             dtype=dtype,
             device=device,
             operations=operations
@@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
         patch_size = self.patch_size
         hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
         orig_shape = hidden_states.shape

-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
-        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+        hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)

+        t_len = t
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
         h_offset = ((h_offset + (patch_size // 2)) // patch_size)
         w_offset = ((w_offset + (patch_size // 2)) // patch_size)

-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
-        img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
-        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
+        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
+        if t_len > 1:
+            img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
+        else:
+            img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
+        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
+        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
+        return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape

-    def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
+        ).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)

     def _forward(
         self,
@@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
         timesteps,
         context,
         attention_mask=None,
-        guidance: torch.Tensor = None,
         ref_latents=None,
+        additional_t_cond=None,
         transformer_options={},
         control=None,
         **kwargs
@@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module):
             index = 0
             ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
             index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
+            negative_ref_method = ref_method == "negative_index"
             timestep_zero = ref_method == "index_timestep_zero"
             for ref in ref_latents:
                 if index_ref_method:
                     index += 1
                     h_offset = 0
                     w_offset = 0
+                elif negative_ref_method:
+                    index -= 1
+                    h_offset = 0
+                    w_offset = 0
                 else:
                     index = 1
                     h_offset = 0
@@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)

-        if guidance is not None:
-            guidance = guidance * 1000
-
-        temb = (
-            self.time_text_embed(timestep, hidden_states)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, hidden_states)
-        )
+        temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)

         patches_replace = transformer_options.get("patches_replace", {})
         patches = transformer_options.get("patches", {})
@@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)

-        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
-        hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
+        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+        hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
         return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
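The reshuffled `view`/`permute`/`reshape` calls above generalize the 2x2 patchify from single images to a latent with a time axis. A shape-only sketch of the same bookkeeping, independent of the model and with made-up sizes, showing that the new unpatchify exactly inverts the new patchify:

```python
import torch

b, c, t, h, w = 1, 16, 3, 64, 64                      # hypothetical latent shape
x = torch.randn(b, c, t, h, w)

# (b, c, t, h, w) -> (b, t * (h//2) * (w//2), c * 4): every 2x2 spatial patch
# becomes one token of width c*4; the time axis just multiplies the token count.
tokens = x.view(b, c, t, h // 2, 2, w // 2, 2)
tokens = tokens.permute(0, 2, 3, 5, 1, 4, 6)
tokens = tokens.reshape(b, t * (h // 2) * (w // 2), c * 4)

# The inverse used after the final projection mirrors the permutation.
restored = tokens.view(b, t, h // 2, w // 2, c, 2, 2).permute(0, 4, 1, 2, 5, 3, 6).reshape(b, c, t, h, w)
assert torch.equal(restored, x)
```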


@@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
     def __init__(self,
                  dim=128,
                  z_dim=4,
+                 input_channels=3,
                  dim_mult=[1, 2, 4, 4],
                  num_res_blocks=2,
                  attn_scales=[],
@@ -245,7 +246,7 @@ class Encoder3d(nn.Module):
         scale = 1.0

         # init block
-        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+        self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)

         # downsample blocks
         downsamples = []
@@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
     def __init__(self,
                  dim=128,
                  z_dim=4,
+                 output_channels=3,
                  dim_mult=[1, 2, 4, 4],
                  num_res_blocks=2,
                  attn_scales=[],
@@ -378,7 +380,7 @@ class Decoder3d(nn.Module):
         # output blocks
         self.head = nn.Sequential(
             RMS_norm(out_dim, images=False), nn.SiLU(),
-            CausalConv3d(out_dim, 3, 3, padding=1))
+            CausalConv3d(out_dim, output_channels, 3, padding=1))

     def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## conv1
@@ -449,6 +451,7 @@ class WanVAE(nn.Module):
                  num_res_blocks=2,
                  attn_scales=[],
                  temperal_downsample=[True, True, False],
+                 image_channels=3,
                  dropout=0.0):
         super().__init__()
         self.dim = dim
@@ -460,11 +463,11 @@ class WanVAE(nn.Module):
         self.temperal_upsample = temperal_downsample[::-1]

         # modules
-        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
+        self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
                                  attn_scales, self.temperal_downsample, dropout)
         self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.conv2 = CausalConv3d(z_dim, z_dim, 1)
-        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
+        self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
                                  attn_scales, self.temperal_upsample, dropout)

     def encode(self, x):


@@ -1110,7 +1110,7 @@ class Lumina2(BaseModel):
         if 'num_tokens' not in out:
             out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])

-        clip_text_pooled = kwargs["pooled_output"] # Newbie
+        clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
         if clip_text_pooled is not None:
             out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)


@@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["rope_theta"] = 10000.0
             dit_config["ffn_dim_multiplier"] = 4.0
             ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
-            if ctd_weight is not None:
+            if ctd_weight is not None: # NewBie
                 dit_config["clip_text_dim"] = ctd_weight.shape[0]
+                # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
         elif dit_config["dim"] == 3840: # Z image
             dit_config["n_heads"] = 30
             dit_config["n_kv_heads"] = 30
@@ -620,6 +621,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
         if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
             dit_config["default_ref_method"] = "index_timestep_zero"
+        if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
+            dit_config["use_additional_t_cond"] = True
+            dit_config["default_ref_method"] = "negative_index"
         return dit_config

     if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5


@@ -28,6 +28,7 @@ import weakref
 import gc
 import os
 from functools import lru_cache

 @lru_cache(maxsize=1)
@@ -49,6 +50,7 @@ def get_offload_reserve_ram_gb():
 def get_free_disk():
     return psutil.disk_usage("/").free

 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram
     NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -355,13 +357,15 @@ except:
 SUPPORT_FP8_OPS = args.supports_fp8_compute

 AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
+AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'

 try:
     if is_amd():
         arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
         if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
-            torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
-            logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
+            if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
+                torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
+                logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")

         try:
             rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
@@ -1055,8 +1059,8 @@ NUM_STREAMS = 0
 if args.async_offload is not None:
     NUM_STREAMS = args.async_offload
 else:
-    # Enable by default on Nvidia
-    if is_nvidia():
+    # Enable by default on Nvidia and AMD
+    if is_nvidia() or is_amd():
         NUM_STREAMS = 2

 if args.disable_async_offload:
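The new `COMFYUI_ENABLE_MIOPEN` switch simply gates the existing "disable cuDNN/MIOpen on newer AMD GPUs" behaviour behind an environment variable, so users can opt back in (for example by exporting `COMFYUI_ENABLE_MIOPEN=1` before launching). A condensed sketch of the pattern outside of ComfyUI:

```python
import os
import torch

AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'

# Opt-out by default: cuDNN (MIOpen on ROCm) stays enabled only when the
# user explicitly sets the environment variable to "1".
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
    torch.backends.cudnn.enabled = False
```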


@@ -984,9 +984,6 @@ class CFGGuider:
         self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
         device = self.model_patcher.load_device

-        if denoise_mask is not None:
-            denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
-
         noise = noise.to(device)
         latent_image = latent_image.to(device)
         sigmas = sigmas.to(device)
@@ -1013,6 +1010,24 @@ class CFGGuider:
         else:
             latent_shapes = [latent_image.shape]

+        if denoise_mask is not None:
+            if denoise_mask.is_nested:
+                denoise_masks = denoise_mask.unbind()
+                denoise_masks = denoise_masks[:len(latent_shapes)]
+            else:
+                denoise_masks = [denoise_mask]
+
+            for i in range(len(denoise_masks), len(latent_shapes)):
+                denoise_masks.append(torch.ones(latent_shapes[i]))
+
+            for i in range(len(denoise_masks)):
+                denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
+
+            if len(denoise_masks) > 1:
+                denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
+            else:
+                denoise_mask = denoise_masks[0]
+
         self.conds = {}
         for k in self.original_conds:
             self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
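The new block gives every latent in a multi-latent batch its own denoise mask, defaulting missing entries to all-ones (fully denoised). A simplified, self-contained sketch of that padding logic, without ComfyUI's `prepare_mask`/`pack_latents` helpers, which additionally handle broadcasting and nested-tensor packing:

```python
import torch

def pad_denoise_masks(denoise_masks, latent_shapes):
    """Return one mask per latent shape, filling missing ones with ones."""
    masks = list(denoise_masks)[:len(latent_shapes)]
    for i in range(len(masks), len(latent_shapes)):
        masks.append(torch.ones(latent_shapes[i]))   # unmasked by default
    return masks

latent_shapes = [(1, 4, 64, 64), (1, 4, 32, 32)]
masks = pad_denoise_masks([torch.zeros(1, 4, 64, 64)], latent_shapes)
assert [tuple(m.shape) for m in masks] == latent_shapes
```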


@@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image
 import comfy.text_encoders.z_image
 import comfy.text_encoders.ovis
 import comfy.text_encoders.kandinsky5
+import comfy.text_encoders.jina_clip_2
+import comfy.text_encoders.newbie

 import comfy.model_patcher
 import comfy.lora
@@ -321,6 +323,7 @@ class VAE:
         self.latent_channels = 4
         self.latent_dim = 2
         self.output_channels = 3
+        self.pad_channel_value = None
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
@@ -435,6 +438,7 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
             self.latent_channels = 64
             self.output_channels = 2
+            self.pad_channel_value = "replicate"
             self.upscale_ratio = 2048
             self.downscale_ratio = 2048
             self.latent_dim = 1
@@ -546,7 +550,9 @@ class VAE:
             self.downscale_index_formula = (4, 8, 8)
             self.latent_dim = 3
             self.latent_channels = 16
-            ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+            self.output_channels = sd["encoder.conv1.weight"].shape[1]
+            self.pad_channel_value = 1.0
+            ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
             self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
             self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
@@ -582,6 +588,7 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
             self.latent_channels = 8
             self.output_channels = 2
+            self.pad_channel_value = "replicate"
             self.upscale_ratio = 4096
             self.downscale_ratio = 4096
             self.latent_dim = 2
@@ -690,17 +697,28 @@ class VAE:
             raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")

     def vae_encode_crop_pixels(self, pixels):
-        if not self.crop_input:
-            return pixels
-
-        downscale_ratio = self.spacial_compression_encode()
-
-        dims = pixels.shape[1:-1]
-        for d in range(len(dims)):
-            x = (dims[d] // downscale_ratio) * downscale_ratio
-            x_offset = (dims[d] % downscale_ratio) // 2
-            if x != dims[d]:
-                pixels = pixels.narrow(d + 1, x_offset, x)
+        if self.crop_input:
+            downscale_ratio = self.spacial_compression_encode()
+            dims = pixels.shape[1:-1]
+            for d in range(len(dims)):
+                x = (dims[d] // downscale_ratio) * downscale_ratio
+                x_offset = (dims[d] % downscale_ratio) // 2
+                if x != dims[d]:
+                    pixels = pixels.narrow(d + 1, x_offset, x)
+
+        if pixels.shape[-1] > self.output_channels:
+            pixels = pixels[..., :self.output_channels]
+        elif pixels.shape[-1] < self.output_channels:
+            if self.pad_channel_value is not None:
+                if isinstance(self.pad_channel_value, str):
+                    mode = self.pad_channel_value
+                    value = None
+                else:
+                    mode = "constant"
+                    value = self.pad_channel_value
+                pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
         return pixels

     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
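The channel-adaptation branch above either truncates surplus input channels or pads missing ones, with the pad mode chosen per VAE (`"replicate"` for the audio VAEs, a constant `1.0` for the Wan VAE variant whose channel count is read from `encoder.conv1.weight`). A standalone sketch of the constant-pad case on channels-last image tensors, with illustrative shapes and values only:

```python
import torch
import torch.nn.functional as F

pixels = torch.rand(1, 64, 64, 3)    # (batch, height, width, channels) RGB input
output_channels = 4                  # hypothetical VAE that expects a 4th channel
pad_channel_value = 1.0              # constant used to fill the missing channel

if pixels.shape[-1] > output_channels:
    pixels = pixels[..., :output_channels]        # drop surplus channels
elif pixels.shape[-1] < output_channels:
    # pad only the last (channel) dimension
    pixels = F.pad(pixels, (0, output_channels - pixels.shape[-1]),
                   mode="constant", value=pad_channel_value)

assert pixels.shape[-1] == output_channels
```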
@@ -992,6 +1010,7 @@ class CLIPType(Enum):
     OVIS = 21
     KANDINSKY5 = 22
     KANDINSKY5_IMAGE = 23
+    NEWBIE = 24

 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -1022,6 +1041,7 @@ class TEModel(Enum):
     MISTRAL3_24B_PRUNED_FLUX2 = 15
     QWEN3_4B = 16
     QWEN3_2B = 17
+    JINA_CLIP_2 = 18

 def detect_te_model(sd):
@@ -1031,6 +1051,8 @@ def detect_te_model(sd):
         return TEModel.CLIP_H
     if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
         return TEModel.CLIP_L
+    if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
+        return TEModel.JINA_CLIP_2
     if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
         weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
         if weight.shape[-1] == 4096:
@@ -1191,6 +1213,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif te_model == TEModel.QWEN3_2B:
             clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
+        elif te_model == TEModel.JINA_CLIP_2:
+            clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
+            clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
         else:
             # clip_l
             if clip_type == CLIPType.SD3:
@@ -1246,6 +1271,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif clip_type == CLIPType.KANDINSKY5_IMAGE:
             clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
+        elif clip_type == CLIPType.NEWBIE:
+            clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
+            if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
+                clip_data_gemma = clip_data[0]
+                clip_data_jina = clip_data[1]
+            else:
+                clip_data_gemma = clip_data[1]
+                clip_data_jina = clip_data[0]
+            tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
+            tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
             clip_target.tokenizer = sdxl_clip.SDXLTokenizer
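With the new `CLIPType.NEWBIE` entry, the dual Gemma 3 4B plus Jina CLIP v2 text encoder is loaded through the existing `load_clip` entry point shown above; a hypothetical usage sketch (the file names are made up, and both checkpoints are passed together so the spiece models can be routed to the right tokenizer):

```python
import comfy.sd

# Hypothetical checkpoint paths; order does not matter, the loader inspects
# the state dicts to tell the Gemma and Jina encoders apart.
clip = comfy.sd.load_clip(
    ckpt_paths=[
        "models/text_encoders/gemma_3_4b.safetensors",
        "models/text_encoders/jina_clip_v2.safetensors",
    ],
    clip_type=comfy.sd.CLIPType.NEWBIE,
)
```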


@@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out

 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@@ -513,6 +513,8 @@ class SDTokenizer:
         self.embedding_size = embedding_size
         self.embedding_key = embedding_key

+        self.disable_weights = disable_weights
+
     def _try_get_embedding(self, embedding_name:str):
         '''
         Takes a potential embedding name and tries to retrieve it.
@@ -547,7 +549,7 @@ class SDTokenizer:
         min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)

         text = escape_important(text)
-        if kwargs.get("disable_weights", False):
+        if kwargs.get("disable_weights", self.disable_weights):
             parsed_weights = [(text, 1.0)]
         else:
             parsed_weights = token_weights(text, 1.0)


@@ -0,0 +1,219 @@ (new file)
# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py

from dataclasses import dataclass

import torch
from torch import nn as nn
from torch.nn import functional as F

import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer


class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)

    def state_dict(self):
        return {"spiece_model": self.tokenizer.serialize_model()}


class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")


# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
    vocab_size: int = 250002
    type_vocab_size: int = 1
    hidden_size: int = 1024
    num_hidden_layers: int = 24
    num_attention_heads: int = 16
    rotary_emb_base: float = 20000.0
    intermediate_size: int = 4096
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-05
    bos_token_id: int = 0
    eos_token_id: int = 2
    pad_token_id: int = 1


class XLMRobertaEmbeddings(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        embed_dim = config.hidden_size
        self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
        self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)

    def forward(self, input_ids=None, embeddings=None):
        if input_ids is not None and embeddings is None:
            embeddings = self.word_embeddings(input_ids)
        if embeddings is not None:
            token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            emb = torch.cat((freqs, freqs), dim=-1)
            self._cos_cached = emb.cos().to(dtype)
            self._sin_cached = emb.sin().to(dtype)

    def forward(self, q, k):
        batch, seqlen, heads, head_dim = q.shape
        self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
        cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
        sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)

        def rotate_half(x):
            size = x.shape[-1] // 2
            x1, x2 = x[..., :size], x[..., size:]
            return torch.cat((-x2, x1), dim=-1)

        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed


class MHA(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = embed_dim // config.num_attention_heads
        self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
        self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
        self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)

    def forward(self, x, mask=None, optimized_attention=None):
        qkv = self.Wqkv(x)
        batch_size, seq_len, _ = qkv.shape
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k = self.rotary_emb(q, k)
        # NHD -> HND
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
        return self.out_proj(out)


class MLP(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
        self.activation = F.gelu
        self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x


class Block(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
        self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
        self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)

    def forward(self, hidden_states, mask=None, optimized_attention=None):
        mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
        hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
        mlp_out = self.mlp(hidden_states)
        hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
        return hidden_states


class XLMRobertaEncoder(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
        for layer in self.layers:
            hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
        return hidden_states


class XLMRobertaModel_(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
        self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)

    def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
        x = self.embeddings(input_ids=input_ids, embeddings=embeds)
        x = self.emb_ln(x)
        x = self.emb_drop(x)

        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        sequence_output = self.encoder(x, attention_mask=mask)

        # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
        pooled_output = None
        if attention_mask is None:
            pooled_output = sequence_output.mean(dim=1)
        else:
            attention_mask = attention_mask.to(sequence_output.dtype)
            pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)

        # Intermediate output is not yet implemented, use None for placeholder
        return sequence_output, None, pooled_output


class XLMRobertaModel(nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.config = XLMRobertaConfig(**config_dict)
        self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
        self.num_layers = self.config.num_hidden_layers

    def get_input_embeddings(self):
        return self.model.embeddings.word_embeddings

    def set_input_embeddings(self, embeddings):
        self.model.embeddings.word_embeddings = embeddings

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class JinaClip2TextModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)


class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
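The pooled output of the new text encoder is a masked mean over the sequence. An equivalent standalone helper (same math as the snippet in `XLMRobertaModel_.forward`, exercised with dummy shapes):

```python
import torch

def masked_mean_pool(sequence_output, attention_mask=None):
    # sequence_output: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1
    if attention_mask is None:
        return sequence_output.mean(dim=1)
    attention_mask = attention_mask.to(sequence_output.dtype)
    summed = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1)
    return summed / attention_mask.sum(dim=-1, keepdim=True)

hidden = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
pooled = masked_mean_pool(hidden, mask)
assert pooled.shape == (2, 8)
```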


@@ -3,7 +3,6 @@ import torch.nn as nn
 from dataclasses import dataclass
 from typing import Optional, Any
 import math
-import logging

 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
@@ -177,7 +176,7 @@ class Gemma3_4B_Config:
     num_key_value_heads: int = 4
     max_position_embeddings: int = 131072
     rms_norm_eps: float = 1e-6
-    rope_theta = [10000.0, 1000000.0]
+    rope_theta = [1000000.0, 10000.0]
     transformer_type: str = "gemma3"
     head_dim = 256
     rms_norm_add = True
@@ -186,8 +185,8 @@ class Gemma3_4B_Config:
     rope_dims = None
     q_norm = "gemma3"
     k_norm = "gemma3"
-    sliding_attention = [False, False, False, False, False, 1024]
-    rope_scale = [1.0, 8.0]
+    sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
+    rope_scale = [8.0, 1.0]
     final_norm: bool = True

 class RMSNorm(nn.Module):
@@ -370,7 +369,7 @@ class TransformerBlockGemma2(nn.Module):
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

-        if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+        if config.sliding_attention is not None:
             self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
         else:
             self.sliding_attention = False
@@ -387,7 +386,12 @@ class TransformerBlockGemma2(nn.Module):
         if self.transformer_type == 'gemma3':
             if self.sliding_attention:
                 if x.shape[1] > self.sliding_attention:
-                    logging.warning("Warning: sliding attention not implemented, results may be incorrect")
+                    sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
+                    sliding_mask.tril_(diagonal=-self.sliding_attention)
+                    if attention_mask is not None:
+                        attention_mask = attention_mask + sliding_mask
+                    else:
+                        attention_mask = sliding_mask
                 freqs_cis = freqs_cis[1]
             else:
                 freqs_cis = freqs_cis[0]
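The replacement for the old warning builds an additive bias that blanks out keys more than `sliding_attention` positions in the past (the reordered config lists cycle that window over the layers via `index % len(...)`). Combined with a causal mask, this yields a sliding attention window; a small self-contained check of the masking arithmetic:

```python
import torch

seq_len, window = 6, 2

causal = torch.full((seq_len, seq_len), float("-inf")).triu_(diagonal=1)  # block future positions
sliding = torch.full((seq_len, seq_len), float("-inf"))
sliding.tril_(diagonal=-window)                                           # block positions more than `window` back

mask = causal + sliding
# Row i can now only attend to positions i-window+1 .. i.
for i in range(seq_len):
    allowed = (mask[i] == 0).nonzero().flatten().tolist()
    assert allowed == list(range(max(0, i - window + 1), i + 1))
```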


@@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
 class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
@@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
 class Gemma3_4BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
+            model_options = model_options.copy()
+            model_options["quantization_metadata"] = llama_quantization_metadata
+
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

 class LuminaModel(sd1_clip.SD1ClipModel):


@@ -0,0 +1,62 @@ (new file)
import torch

import comfy.model_management
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.lumina2


class NewBieTokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        raise NotImplementedError

    def state_dict(self):
        return {}


class NewBieTEModel(torch.nn.Module):
    def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
        self.dtypes = {dtype, dtype_gemma}

    def set_clip_options(self, options):
        self.gemma.set_clip_options(options)
        self.jina.set_clip_options(options)

    def reset_clip_options(self):
        self.gemma.reset_clip_options()
        self.jina.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_gemma = token_weight_pairs["gemma"]
        token_weight_pairs_jina = token_weight_pairs["jina"]

        gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
        jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)
        return gemma_out, jina_pooled, gemma_extra

    def load_sd(self, sd):
        if "model.layers.0.self_attn.q_norm.weight" in sd:
            return self.gemma.load_sd(sd)
        else:
            return self.jina.load_sd(sd)


def te(dtype_llama=None, llama_quantization_metadata=None):
    class NewBieTEModel_(NewBieTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_quantization_metadata is not None:
                model_options = model_options.copy()
                model_options["llama_quantization_metadata"] = llama_quantization_metadata
            super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
    return NewBieTEModel_


@@ -28,9 +28,8 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr
                                 prune_dict, shallow_clone_class)
 from ._resources import Resources, ResourcesLocal
 from comfy_execution.graph_utils import ExecutionBlocker
-from ._util import MESH, VOXEL
-# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
+from ._util import MESH, VOXEL, SVG as _SVG

 class FolderType(str, Enum):
     input = "input"
@@ -656,7 +655,7 @@ class Video(ComfyTypeIO):

 @comfytype(io_type="SVG")
 class SVG(ComfyTypeIO):
-    Type = Any  # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3
+    Type = _SVG

 @comfytype(io_type="LORA_MODEL")
 class LoraModel(ComfyTypeIO):


@@ -1,5 +1,6 @@
 from .video_types import VideoContainer, VideoCodec, VideoComponents
 from .geometry_types import VOXEL, MESH
+from .image_types import SVG

 __all__ = [
     # Utility Types
@@ -8,4 +9,5 @@ __all__ = [
     "VideoComponents",
     "VOXEL",
     "MESH",
+    "SVG",
 ]


@@ -0,0 +1,18 @@ (new file)
from io import BytesIO


class SVG:
    """Stores SVG representations via a list of BytesIO objects."""
    def __init__(self, data: list[BytesIO]):
        self.data = data

    def combine(self, other: 'SVG') -> 'SVG':
        return SVG(self.data + other.data)

    @staticmethod
    def combine_all(svgs: list['SVG']) -> 'SVG':
        all_svgs_list: list[BytesIO] = []
        for svg_item in svgs:
            all_svgs_list.extend(svg_item.data)
        return SVG(all_svgs_list)
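Since the `SVG` container is now importable from `comfy_api.latest._util` (and exposed as the `Type` of the v3 `SVG` ComfyType), combining node outputs is plain list concatenation; a quick usage sketch, assuming the ComfyUI package is on the path:

```python
from io import BytesIO

from comfy_api.latest._util import SVG  # as wired up in the __init__ change above

a = SVG([BytesIO(b"<svg><!-- first --></svg>")])
b = SVG([BytesIO(b"<svg><!-- second --></svg>")])

merged = a.combine(b)                 # first followed by second
everything = SVG.combine_all([a, b])  # same contents, built from a list of SVG objects
assert len(merged.data) == len(everything.data) == 2
```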


@@ -10,7 +10,7 @@ class Text2ImageTaskCreationRequest(BaseModel):
     size: str | None = Field(None)
     seed: int | None = Field(0, ge=0, le=2147483647)
     guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
-    watermark: bool | None = Field(True)
+    watermark: bool | None = Field(False)

 class Image2ImageTaskCreationRequest(BaseModel):
@@ -21,7 +21,7 @@ class Image2ImageTaskCreationRequest(BaseModel):
     size: str | None = Field("adaptive")
     seed: int | None = Field(..., ge=0, le=2147483647)
     guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
-    watermark: bool | None = Field(True)
+    watermark: bool | None = Field(False)

 class Seedream4Options(BaseModel):
@@ -37,7 +37,7 @@ class Seedream4TaskCreationRequest(BaseModel):
     seed: int = Field(..., ge=0, le=2147483647)
     sequential_image_generation: str = Field("disabled")
     sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
-    watermark: bool = Field(True)
+    watermark: bool = Field(False)

 class ImageTaskCreationResponse(BaseModel):


@@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel):
     systemInstruction: GeminiSystemInstructionContent | None = Field(None)
     tools: list[GeminiTool] | None = Field(None)
     videoMetadata: GeminiVideoMetadata | None = Field(None)
+    uploadImagesToStorage: bool = Field(True)

 class GeminiGenerateContentRequest(BaseModel):


@@ -102,3 +102,12 @@ class ImageToVideoWithAudioRequest(BaseModel):
     prompt: str = Field(...)
     mode: str = Field("pro")
     sound: str = Field(..., description="'on' or 'off'")
+
+
+class MotionControlRequest(BaseModel):
+    prompt: str = Field(...)
+    image_url: str = Field(...)
+    video_url: str = Field(...)
+    keep_original_sound: str = Field(...)
+    character_orientation: str = Field(...)
+    mode: str = Field(..., description="'pro' or 'std'")


@@ -1,10 +1,8 @@
-from inspect import cleandoc
-
 import torch
 from pydantic import BaseModel
 from typing_extensions import override

-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
 from comfy_api_nodes.apis.bfl_api import (
     BFLFluxExpandImageRequest,
     BFLFluxFillImageRequest,
@@ -28,7 +26,7 @@ from comfy_api_nodes.util import (
 )

-def convert_mask_to_image(mask: torch.Tensor):
+def convert_mask_to_image(mask: Input.Image):
     """
     Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
     """
@@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor):

 class FluxProUltraImageNode(IO.ComfyNode):
-    """
-    Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
-    """

     @classmethod
     def define_schema(cls) -> IO.Schema:
@@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
             node_id="FluxProUltraImageNode",
             display_name="Flux 1.1 [pro] Ultra Image",
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
         prompt_upsampling: bool = False,
         raw: bool = False,
         seed: int = 0,
-        image_prompt: torch.Tensor | None = None,
+        image_prompt: Input.Image | None = None,
         image_prompt_strength: float = 0.1,
     ) -> IO.NodeOutput:
         if image_prompt is None:
@@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode):

 class FluxKontextProImageNode(IO.ComfyNode):
-    """
-    Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
-    """

     @classmethod
     def define_schema(cls) -> IO.Schema:
@@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
             node_id=cls.NODE_ID,
             display_name=cls.DISPLAY_NAME,
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
         aspect_ratio: str,
         guidance: float,
         steps: int,
-        input_image: torch.Tensor | None = None,
+        input_image: Input.Image | None = None,
         seed=0,
         prompt_upsampling=False,
     ) -> IO.NodeOutput:
@@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode):

 class FluxKontextMaxImageNode(FluxKontextProImageNode):
-    """
-    Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
-    """

-    DESCRIPTION = cleandoc(__doc__ or "")
+    DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
     BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
     NODE_ID = "FluxKontextMaxImageNode"
     DISPLAY_NAME = "Flux.1 Kontext [max] Image"

 class FluxProExpandNode(IO.ComfyNode):
-    """
-    Outpaints image based on prompt.
-    """

     @classmethod
     def define_schema(cls) -> IO.Schema:
@@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode):
             node_id="FluxProExpandNode",
             display_name="Flux.1 Expand Image",
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Outpaints image based on prompt.",
             inputs=[
                 IO.Image.Input("image"),
                 IO.String.Input(
@@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode):
     @classmethod
     async def execute(
         cls,
-        image: torch.Tensor,
+        image: Input.Image,
         prompt: str,
         prompt_upsampling: bool,
top: int, top: int,
@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode):
class FluxProFillNode(IO.ComfyNode): class FluxProFillNode(IO.ComfyNode):
"""
Inpaints image based on mask and prompt.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode):
node_id="FluxProFillNode", node_id="FluxProFillNode",
display_name="Flux.1 Fill Image", display_name="Flux.1 Fill Image",
category="api node/image/BFL", category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""), description="Inpaints image based on mask and prompt.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.Mask.Input("mask"), IO.Mask.Input("mask"),
@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
image: torch.Tensor, image: Input.Image,
mask: torch.Tensor, mask: Input.Image,
prompt: str, prompt: str,
prompt_upsampling: bool, prompt_upsampling: bool,
steps: int, steps: int,
@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode):
class Flux2ProImageNode(IO.ComfyNode): class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
DISPLAY_NAME = "Flux.2 [pro] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="Flux2ProImageNode", node_id=cls.NODE_ID,
display_name="Flux.2 [pro] Image", display_name=cls.DISPLAY_NAME,
category="api node/image/BFL", category="api node/image/BFL",
description="Generates images synchronously based on prompt and resolution.", description="Generates images synchronously based on prompt and resolution.",
inputs=[ inputs=[
@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"prompt_upsampling", "prompt_upsampling",
default=False, default=True,
tooltip="Whether to perform upsampling on the prompt. " tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, " "If active, automatically modifies the prompt for more creative generation.",
"but results are nondeterministic (same seed will not produce exactly the same result).",
), ),
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
], ],
outputs=[IO.Image.Output()], outputs=[IO.Image.Output()],
hidden=[ hidden=[
@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode):
height: int, height: int,
seed: int, seed: int,
prompt_upsampling: bool, prompt_upsampling: bool,
images: torch.Tensor | None = None, images: Input.Image | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
reference_images = {} reference_images = {}
if images is not None: if images is not None:
@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode):
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
response_model=BFLFluxProGenerateResponse, response_model=BFLFluxProGenerateResponse,
data=Flux2ProGenerateRequest( data=Flux2ProGenerateRequest(
prompt=prompt, prompt=prompt,
@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2MaxImageNode(Flux2ProImageNode):
NODE_ID = "Flux2MaxImageNode"
DISPLAY_NAME = "Flux.2 [max] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
class BFLExtension(ComfyExtension): class BFLExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension):
FluxProExpandNode, FluxProExpandNode,
FluxProFillNode, FluxProFillNode,
Flux2ProImageNode, Flux2ProImageNode,
Flux2MaxImageNode,
] ]


@ -112,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the image', tooltip='Whether to add an "AI generated" watermark to the image',
optional=True, optional=True,
), ),
@ -215,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the image', tooltip='Whether to add an "AI generated" watermark to the image',
optional=True, optional=True,
), ),
@ -346,7 +346,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the image.', tooltip='Whether to add an "AI generated" watermark to the image.',
optional=True, optional=True,
), ),
@ -380,7 +380,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
sequential_image_generation: str = "disabled", sequential_image_generation: str = "disabled",
max_images: int = 1, max_images: int = 1,
seed: int = 0, seed: int = 0,
watermark: bool = True, watermark: bool = False,
fail_on_partial: bool = True, fail_on_partial: bool = True,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
@ -507,7 +507,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the video.', tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True, optional=True,
), ),
@ -617,7 +617,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the video.', tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True, optional=True,
), ),
@ -739,7 +739,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the video.', tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True, optional=True,
), ),
@ -862,7 +862,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip='Whether to add an "AI generated" watermark to the video.', tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True, optional=True,
), ),


@ -34,6 +34,7 @@ from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
audio_to_base64_string, audio_to_base64_string,
bytesio_to_image_tensor, bytesio_to_image_tensor,
download_url_to_image_tensor,
get_number_of_images, get_number_of_images,
sync_op, sync_op,
tensor_to_base64_string, tensor_to_base64_string,
@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
) )
parts = [] parts = []
for part in response.candidates[0].content.parts: for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text: if part_type == "text" and part.text:
parts.append(part) parts.append(part)
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: elif part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part) parts.append(part)
# Skip parts that don't match the requested type # Skip parts that don't match the requested type
return parts return parts
@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts]) return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = [] image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png") parts = get_parts_by_type(response, "image/png")
for part in parts: for part in parts:
image_data = base64.b64decode(part.inlineData.data) if part.inlineData:
returned_image = bytesio_to_image_tensor(BytesIO(image_data)) image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
else:
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image) image_tensors.append(returned_image)
if len(image_tensors) == 0: if len(image_tensors) == 0:
return torch.zeros((1, 1024, 1024, 4)) return torch.zeros((1, 1024, 1024, 4))
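
get_image_from_response now has to handle two shapes of response part: inline base64 data that can be decoded locally, and a hosted file URI that must be downloaded, which is why the helper became async. A rough sketch of that dispatch, with stand-in names for everything outside this diff (Part, decode_png_bytes, and fetch_image_from_url are not the real response model or helpers):

```python
import base64
from dataclasses import dataclass
from io import BytesIO


@dataclass
class Part:
    inline_data: bytes | None = None   # base64-encoded PNG returned inline
    file_uri: str | None = None        # or a URI that has to be downloaded


def decode_png_bytes(buf: BytesIO) -> bytes:
    # Stand-in for bytesio_to_image_tensor.
    return buf.read()


async def fetch_image_from_url(uri: str) -> bytes:
    # Stand-in for download_url_to_image_tensor; network access is stubbed out here.
    raise NotImplementedError("stub for the sketch")


async def image_from_part(part: Part) -> bytes:
    if part.inline_data is not None:
        # Inline payloads are base64 strings and can be decoded without I/O ...
        return decode_png_bytes(BytesIO(base64.b64decode(part.inline_data)))
    # ... while file parts only carry a URI and must be awaited, which is what
    # forces get_image_from_response itself to be async.
    return await fetch_image_from_url(part.file_uri)
```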
@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest( data=GeminiImageGenerateContentRequest(
contents=[ contents=[
GeminiContent(role=GeminiRole.user, parts=parts), GeminiContent(role=GeminiRole.user, parts=parts),
@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode):
response_model=GeminiGenerateContentResponse, response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price, price_extractor=calculate_tokens_price,
) )
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiImage2(IO.ComfyNode): class GeminiImage2(IO.ComfyNode):
@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
data=GeminiImageGenerateContentRequest( data=GeminiImageGenerateContentRequest(
contents=[ contents=[
GeminiContent(role=GeminiRole.user, parts=parts), GeminiContent(role=GeminiRole.user, parts=parts),
@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse, response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price, price_extractor=calculate_tokens_price,
) )
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension): class GeminiExtension(ComfyExtension):


@ -51,6 +51,7 @@ from comfy_api_nodes.apis import (
) )
from comfy_api_nodes.apis.kling_api import ( from comfy_api_nodes.apis.kling_api import (
ImageToVideoWithAudioRequest, ImageToVideoWithAudioRequest,
MotionControlRequest,
OmniImageParamImage, OmniImageParamImage,
OmniParamImage, OmniParamImage,
OmniParamVideo, OmniParamVideo,
@ -858,7 +859,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
tooltip="A text prompt describing the video content. " tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.", "This can include both positive and negative descriptions.",
), ),
IO.Combo.Input("duration", options=["5", "10"]), IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
IO.Image.Input("first_frame"), IO.Image.Input("first_frame"),
IO.Image.Input( IO.Image.Input(
"end_frame", "end_frame",
@ -897,6 +898,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
validate_string(prompt, min_length=1, max_length=2500) validate_string(prompt, min_length=1, max_length=2500)
if end_frame is not None and reference_images is not None: if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
if duration not in (5, 10) and end_frame is None and reference_images is None:
raise ValueError(
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
)
validate_image_dimensions(first_frame, min_width=300, min_height=300) validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [ image_list: list[OmniParamImage] = [
@ -2159,6 +2164,91 @@ class ImageToVideoWithAudio(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class MotionControl(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingMotionControl",
display_name="Kling Motion Control",
category="api node/video/Kling",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.Image.Input("reference_image"),
IO.Video.Input(
"reference_video",
tooltip="Motion reference video used to drive movement/expression.\n"
"Duration limits depend on character_orientation:\n"
" - image: 310s (max 10s)\n"
" - video: 330s (max 30s)",
),
IO.Boolean.Input("keep_original_sound", default=True),
IO.Combo.Input(
"character_orientation",
options=["video", "image"],
tooltip="Controls where the character's facing/orientation comes from.\n"
"video: movements, expressions, camera moves, and orientation "
"follow the motion reference video (other details via prompt).\n"
"image: movements and expressions still follow the motion reference video, "
"but the character orientation matches the reference image (camera/other details via prompt).",
),
IO.Combo.Input("mode", options=["pro", "std"]),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
reference_image: Input.Image,
reference_video: Input.Video,
keep_original_sound: bool,
character_orientation: str,
mode: str,
) -> IO.NodeOutput:
validate_string(prompt, max_length=2500)
validate_image_dimensions(reference_image, min_width=340, min_height=340)
validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1))
if character_orientation == "image":
validate_video_duration(reference_video, min_duration=3, max_duration=10)
else:
validate_video_duration(reference_video, min_duration=3, max_duration=30)
validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"),
response_model=TaskStatusResponse,
data=MotionControlRequest(
prompt=prompt,
image_url=(await upload_images_to_comfyapi(cls, reference_image))[0],
video_url=await upload_video_to_comfyapi(cls, reference_video),
keep_original_sound="yes" if keep_original_sound else "no",
character_orientation=character_orientation,
mode=mode,
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
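
MotionControl follows the same two-step pattern as the other Kling nodes: sync_op submits the task and returns a task_id, then poll_op re-reads the per-task endpoint through the status_extractor until the task finishes. A library-free sketch of that create-then-poll protocol (submit_json, get_json, and the terminal status names are assumptions of this sketch, not the comfy_api_nodes helpers):

```python
import asyncio


async def submit_json(path: str, payload: dict) -> dict:
    """Hypothetical HTTP POST helper; stands in for sync_op."""
    raise NotImplementedError("stubbed out in this sketch")


async def get_json(path: str) -> dict:
    """Hypothetical HTTP GET helper; stands in for poll_op's polling request."""
    raise NotImplementedError("stubbed out in this sketch")


async def run_motion_control(payload: dict, poll_interval: float = 5.0) -> dict:
    # Step 1: create the task. A non-zero "code" means the request was rejected,
    # which the node surfaces as a RuntimeError.
    created = await submit_json("/proxy/kling/v1/videos/motion-control", payload)
    if created.get("code"):
        raise RuntimeError(f"Kling request failed: {created}")
    task_id = created["data"]["task_id"]

    # Step 2: poll the per-task endpoint until a terminal state is reported
    # ("succeed"/"failed" are assumed names for this sketch).
    while True:
        status = await get_json(f"/proxy/kling/v1/videos/motion-control/{task_id}")
        task_status = (status.get("data") or {}).get("task_status")
        if task_status in ("succeed", "failed"):
            return status
        await asyncio.sleep(poll_interval)
```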
class KlingExtension(ComfyExtension): class KlingExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2184,6 +2274,7 @@ class KlingExtension(ComfyExtension):
OmniProImageNode, OmniProImageNode,
TextToVideoWithAudio, TextToVideoWithAudio,
ImageToVideoWithAudio, ImageToVideoWithAudio,
MotionControl,
] ]


@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = {
"Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1", "Starlight (Astra) Creative": "slc-1",
} }
UPSCALER_VALUES_MAP = {
"FullHD (1080p)": 1920,
"4K (2160p)": 3840,
}
class TopazImageEnhance(IO.ComfyNode): class TopazImageEnhance(IO.ComfyNode):
@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode):
IO.Video.Input("video"), IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True), IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())), IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input( IO.Combo.Input(
"upscaler_creativity", "upscaler_creativity",
options=["low", "middle", "high"], options=["low", "middle", "high"],
@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode):
target_frame_rate = src_frame_rate target_frame_rate = src_frame_rate
filters = [] filters = []
if upscaler_enabled: if upscaler_enabled:
target_width = UPSCALER_VALUES_MAP[upscaler_resolution] if "1080p" in upscaler_resolution:
target_height = UPSCALER_VALUES_MAP[upscaler_resolution] target_pixel_p = 1080
max_long_side = 1920
else:
target_pixel_p = 2160
max_long_side = 3840
ar = src_width / src_height
if src_width >= src_height:
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
target_height = target_pixel_p
target_width = int(target_height * ar)
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
if target_width > max_long_side:
target_width = max_long_side
target_height = int(target_width / ar)
else:
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
target_width = target_pixel_p
target_height = int(target_width / ar)
# Check if height exceeds standard bounds
if target_height > max_long_side:
target_height = max_long_side
target_width = int(target_height * ar)
if target_width % 2 != 0:
target_width += 1
if target_height % 2 != 0:
target_height += 1
filters.append( filters.append(
topaz_api.VideoEnhancementFilter( topaz_api.VideoEnhancementFilter(
model=UPSCALER_MODELS_MAP[upscaler_model], model=UPSCALER_MODELS_MAP[upscaler_model],
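
The replacement for UPSCALER_VALUES_MAP keeps the source aspect ratio instead of forcing a square target: the short side is aimed at 1080 or 2160, the long side is capped at 1920 or 3840, and both are nudged up to even numbers for the encoder. A standalone sketch of the same arithmetic with worked examples, using only the values shown in the hunk above:

```python
def target_dims(src_width: int, src_height: int, resolution: str) -> tuple[int, int]:
    """Aspect-ratio-preserving target size, mirroring the logic introduced above."""
    if "1080p" in resolution:
        target_pixel_p, max_long_side = 1080, 1920
    else:
        target_pixel_p, max_long_side = 2160, 3840

    ar = src_width / src_height
    if src_width >= src_height:
        # Landscape/square: pin the height, derive the width, cap ultra-wide sources.
        target_height = target_pixel_p
        target_width = int(target_height * ar)
        if target_width > max_long_side:
            target_width = max_long_side
            target_height = int(target_width / ar)
    else:
        # Portrait: pin the width, derive the height, cap extreme aspect ratios.
        target_width = target_pixel_p
        target_height = int(target_width / ar)
        if target_height > max_long_side:
            target_height = max_long_side
            target_width = int(target_height * ar)

    # Video encoders generally require even dimensions.
    target_width += target_width % 2
    target_height += target_height % 2
    return target_width, target_height


print(target_dims(1920, 1080, "4K (2160p)"))   # (3840, 2160)
print(target_dims(1080, 1920, "4K (2160p)"))   # (2160, 3840)
print(target_dims(2560, 1080, "4K (2160p)"))   # (3840, 1620)  ultra-wide capped at 3840
```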


@ -168,6 +168,8 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Only add generateAudio for Veo 3 models # Only add generateAudio for Veo 3 models
if model.find("veo-2.0") == -1: if model.find("veo-2.0") == -1:
parameters["generateAudio"] = generate_audio parameters["generateAudio"] = generate_audio
# force "enhance_prompt" to True for Veo3 models
parameters["enhancePrompt"] = True
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
@ -291,7 +293,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
IO.Boolean.Input( IO.Boolean.Input(
"enhance_prompt", "enhance_prompt",
default=True, default=True,
tooltip="Whether to enhance the prompt with AI assistance", tooltip="This parameter is deprecated and ignored.",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(


@ -46,14 +46,14 @@ class Txt2ImageParametersField(BaseModel):
n: int = Field(1, description="Number of images to generate.") # we support only value=1 n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)
prompt_extend: bool = Field(True) prompt_extend: bool = Field(True)
watermark: bool = Field(True) watermark: bool = Field(False)
class Image2ImageParametersField(BaseModel): class Image2ImageParametersField(BaseModel):
size: str | None = Field(None) size: str | None = Field(None)
n: int = Field(1, description="Number of images to generate.") # we support only value=1 n: int = Field(1, description="Number of images to generate.") # we support only value=1
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(True) watermark: bool = Field(False)
class Text2VideoParametersField(BaseModel): class Text2VideoParametersField(BaseModel):
@ -61,7 +61,7 @@ class Text2VideoParametersField(BaseModel):
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=15) duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True) prompt_extend: bool = Field(True)
watermark: bool = Field(True) watermark: bool = Field(False)
audio: bool = Field(False, description="Whether to generate audio automatically.") audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single") shot_type: str = Field("single")
@ -71,7 +71,7 @@ class Image2VideoParametersField(BaseModel):
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)
duration: int = Field(5, ge=5, le=15) duration: int = Field(5, ge=5, le=15)
prompt_extend: bool = Field(True) prompt_extend: bool = Field(True)
watermark: bool = Field(True) watermark: bool = Field(False)
audio: bool = Field(False, description="Whether to generate audio automatically.") audio: bool = Field(False, description="Whether to generate audio automatically.")
shot_type: str = Field("single") shot_type: str = Field("single")
@ -208,7 +208,7 @@ class WanTextToImageApi(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip="Whether to add an AI-generated watermark to the result.", tooltip="Whether to add an AI-generated watermark to the result.",
optional=True, optional=True,
), ),
@ -234,7 +234,7 @@ class WanTextToImageApi(IO.ComfyNode):
height: int = 1024, height: int = 1024,
seed: int = 0, seed: int = 0,
prompt_extend: bool = True, prompt_extend: bool = True,
watermark: bool = True, watermark: bool = False,
): ):
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
@ -327,7 +327,7 @@ class WanImageToImageApi(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip="Whether to add an AI-generated watermark to the result.", tooltip="Whether to add an AI-generated watermark to the result.",
optional=True, optional=True,
), ),
@ -353,7 +353,7 @@ class WanImageToImageApi(IO.ComfyNode):
# width: int = 1024, # width: int = 1024,
# height: int = 1024, # height: int = 1024,
seed: int = 0, seed: int = 0,
watermark: bool = True, watermark: bool = False,
): ):
n_images = get_number_of_images(image) n_images = get_number_of_images(image)
if n_images not in (1, 2): if n_images not in (1, 2):
@ -476,7 +476,7 @@ class WanTextToVideoApi(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip="Whether to add an AI-generated watermark to the result.", tooltip="Whether to add an AI-generated watermark to the result.",
optional=True, optional=True,
), ),
@ -512,7 +512,7 @@ class WanTextToVideoApi(IO.ComfyNode):
seed: int = 0, seed: int = 0,
generate_audio: bool = False, generate_audio: bool = False,
prompt_extend: bool = True, prompt_extend: bool = True,
watermark: bool = True, watermark: bool = False,
shot_type: str = "single", shot_type: str = "single",
): ):
if "480p" in size and model == "wan2.6-t2v": if "480p" in size and model == "wan2.6-t2v":
@ -637,7 +637,7 @@ class WanImageToVideoApi(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=False,
tooltip="Whether to add an AI-generated watermark to the result.", tooltip="Whether to add an AI-generated watermark to the result.",
optional=True, optional=True,
), ),
@ -674,7 +674,7 @@ class WanImageToVideoApi(IO.ComfyNode):
seed: int = 0, seed: int = 0,
generate_audio: bool = False, generate_audio: bool = False,
prompt_extend: bool = True, prompt_extend: bool = True,
watermark: bool = True, watermark: bool = False,
shot_type: str = "single", shot_type: str = "single",
): ):
if get_number_of_images(image) != 1: if get_number_of_images(image) != 1:


@ -430,9 +430,9 @@ def _display_text(
if status: if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
if price is not None: if price is not None:
p = f"{float(price):,.4f}".rstrip("0").rstrip(".") p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".")
if p != "0": if p != "0":
display_lines.append(f"Price: ${p}") display_lines.append(f"Price: {p} credits")
if text is not None: if text is not None:
display_lines.append(text) display_lines.append(text)
if display_lines: if display_lines:
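
The display now converts the price to credits and trims trailing zeros; the rstrip chain is what turns "211.0" into "211" without touching "8.4". A quick illustration of the formatting, with the 211 credits-per-dollar factor taken from the hunk above:

```python
def format_credits(price_usd: float, credits_per_usd: float = 211) -> str:
    p = f"{price_usd * credits_per_usd:,.1f}".rstrip("0").rstrip(".")
    return f"{p} credits" if p != "0" else ""


print(format_credits(0.04))   # '8.4 credits'    (8.44 rounds to '8.4')
print(format_credits(1.0))    # '211 credits'    ('211.0' stripped to '211')
print(format_credits(0.0))    # ''               (zero prices are not displayed)
```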


@ -9,6 +9,7 @@ import comfy.utils
import node_helpers import node_helpers
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
import re
class BasicScheduler(io.ComfyNode): class BasicScheduler(io.ComfyNode):
@ -760,8 +761,12 @@ class SamplerCustom(io.ComfyNode):
out = latent.copy() out = latent.copy()
out["samples"] = samples out["samples"] = samples
if "x0" in x0_output: if "x0" in x0_output:
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
if samples.is_nested:
latent_shapes = [x.shape for x in samples.unbind()]
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
out_denoised = latent.copy() out_denoised = latent.copy()
out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) out_denoised["samples"] = x0_out
else: else:
out_denoised = out out_denoised = out
return io.NodeOutput(out, out_denoised) return io.NodeOutput(out, out_denoised)
@ -948,8 +953,12 @@ class SamplerCustomAdvanced(io.ComfyNode):
out = latent.copy() out = latent.copy()
out["samples"] = samples out["samples"] = samples
if "x0" in x0_output: if "x0" in x0_output:
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
if samples.is_nested:
latent_shapes = [x.shape for x in samples.unbind()]
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
out_denoised = latent.copy() out_denoised = latent.copy()
out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) out_denoised["samples"] = x0_out
else: else:
out_denoised = out out_denoised = out
return io.NodeOutput(out, out_denoised) return io.NodeOutput(out, out_denoised)
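
Both samplers now re-wrap the denoised x0 when the incoming latent batch is a nested tensor of differently shaped latents: per-item shapes are recorded from the nested samples and used to split the processed x0 back apart. The exact packing is internal to comfy.utils, so the sketch below only shows the shape-bookkeeping idea with illustrative helpers (pack_flat and unpack_by_shapes are not the comfy.utils functions):

```python
import math
import torch


def pack_flat(latents: list[torch.Tensor]) -> torch.Tensor:
    # Flatten each latent and concatenate, so differently shaped items share one tensor.
    return torch.cat([x.reshape(-1) for x in latents])


def unpack_by_shapes(flat: torch.Tensor, shapes: list[torch.Size]) -> list[torch.Tensor]:
    # Split the flat buffer back into the original shapes, in order.
    sizes = [math.prod(s) for s in shapes]
    return [chunk.reshape(s) for chunk, s in zip(torch.split(flat, sizes), shapes)]


latents = [torch.randn(1, 4, 32, 32), torch.randn(1, 4, 48, 64)]
shapes = [x.shape for x in latents]      # recorded like the hunk does via unbind()
flat = pack_flat(latents)

x0 = flat * 0.5                          # stand-in for the model's denoised prediction
restored = unpack_by_shapes(x0, shapes)
assert [x.shape for x in restored] == shapes
```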
@ -1005,6 +1014,25 @@ class AddNoise(io.ComfyNode):
add_noise = execute add_noise = execute
class ManualSigmas(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ManualSigmas",
category="_for_testing/custom_sampling",
is_experimental=True,
inputs=[
io.String.Input("sigmas", default="1, 0.5", multiline=False)
],
outputs=[io.Sigmas.Output()]
)
@classmethod
def execute(cls, sigmas) -> io.NodeOutput:
sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas)
sigmas = [float(i) for i in sigmas]
sigmas = torch.FloatTensor(sigmas)
return io.NodeOutput(sigmas)
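
ManualSigmas turns a free-form string into a sigma schedule by regex-extracting every numeric token, so separators and stray text are simply ignored. A quick illustration of the same parsing outside the node:

```python
import re
import torch


def parse_sigmas(text: str) -> torch.Tensor:
    # Same extraction as the node: pull every numeric token, ignore everything else.
    values = [float(v) for v in re.findall(r"[-+]?(?:\d*\.*\d+)", text)]
    return torch.FloatTensor(values)


print(parse_sigmas("1, 0.5"))                    # tensor([1.0000, 0.5000])
print(parse_sigmas("1.0 0.75; 0.5 / 0.25 0"))    # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])
```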
class CustomSamplersExtension(ComfyExtension): class CustomSamplersExtension(ComfyExtension):
@override @override
@ -1044,6 +1072,7 @@ class CustomSamplersExtension(ComfyExtension):
DisableNoise, DisableNoise,
AddNoise, AddNoise,
SamplerCustomAdvanced, SamplerCustomAdvanced,
ManualSigmas,
] ]


@ -2,280 +2,231 @@ from __future__ import annotations
import nodes import nodes
import folder_paths import folder_paths
from comfy.cli_args import args
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json import json
import os import os
import re import re
from io import BytesIO
from inspect import cleandoc
import torch import torch
import comfy.utils import comfy.utils
from comfy.comfy_types import FileLocator, IO
from server import PromptServer from server import PromptServer
from comfy_api.latest import ComfyExtension, IO, UI
from typing_extensions import override
SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
MAX_RESOLUTION = nodes.MAX_RESOLUTION MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop: class ImageCrop(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), return IO.Schema(
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), node_id="ImageCrop",
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), display_name="Image Crop",
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), category="image/transform",
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), inputs=[
}} IO.Image.Input("image"),
RETURN_TYPES = ("IMAGE",) IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
FUNCTION = "crop" IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/transform" @classmethod
def execute(cls, image, width, height, x, y) -> IO.NodeOutput:
def crop(self, image, width, height, x, y):
x = min(x, image.shape[2] - 1) x = min(x, image.shape[2] - 1)
y = min(y, image.shape[1] - 1) y = min(y, image.shape[1] - 1)
to_x = width + x to_x = width + x
to_y = height + y to_y = height + y
img = image[:,y:to_y, x:to_x, :] img = image[:,y:to_y, x:to_x, :]
return (img,) return IO.NodeOutput(img)
class RepeatImageBatch: crop = execute # TODO: remove
class RepeatImageBatch(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), return IO.Schema(
"amount": ("INT", {"default": 1, "min": 1, "max": 4096}), node_id="RepeatImageBatch",
}} category="image/batch",
RETURN_TYPES = ("IMAGE",) inputs=[
FUNCTION = "repeat" IO.Image.Input("image"),
IO.Int.Input("amount", default=1, min=1, max=4096),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/batch" @classmethod
def execute(cls, image, amount) -> IO.NodeOutput:
def repeat(self, image, amount):
s = image.repeat((amount, 1,1,1)) s = image.repeat((amount, 1,1,1))
return (s,) return IO.NodeOutput(s)
class ImageFromBatch: repeat = execute # TODO: remove
class ImageFromBatch(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), return IO.Schema(
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), node_id="ImageFromBatch",
"length": ("INT", {"default": 1, "min": 1, "max": 4096}), category="image/batch",
}} inputs=[
RETURN_TYPES = ("IMAGE",) IO.Image.Input("image"),
FUNCTION = "frombatch" IO.Int.Input("batch_index", default=0, min=0, max=4095),
IO.Int.Input("length", default=1, min=1, max=4096),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/batch" @classmethod
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
def frombatch(self, image, batch_index, length):
s_in = image s_in = image
batch_index = min(s_in.shape[0] - 1, batch_index) batch_index = min(s_in.shape[0] - 1, batch_index)
length = min(s_in.shape[0] - batch_index, length) length = min(s_in.shape[0] - batch_index, length)
s = s_in[batch_index:batch_index + length].clone() s = s_in[batch_index:batch_index + length].clone()
return (s,) return IO.NodeOutput(s)
frombatch = execute # TODO: remove
class ImageAddNoise: class ImageAddNoise(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": ("IMAGE",), return IO.Schema(
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), node_id="ImageAddNoise",
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), category="image",
}} inputs=[
RETURN_TYPES = ("IMAGE",) IO.Image.Input("image"),
FUNCTION = "repeat" IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image" @classmethod
def execute(cls, image, seed, strength) -> IO.NodeOutput:
def repeat(self, image, seed, strength):
generator = torch.manual_seed(seed) generator = torch.manual_seed(seed)
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
return (s,) return IO.NodeOutput(s)
class SaveAnimatedWEBP: repeat = execute # TODO: remove
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": True}),
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = () class SaveAnimatedWEBP(IO.ComfyNode):
FUNCTION = "save_images" COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6}
OUTPUT_NODE = True
CATEGORY = "image/animation"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results: list[FileLocator] = []
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = pil_images[0].getexif()
if not args.disable_metadata:
if prompt is not None:
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
if extra_pnginfo is not None:
inital_exif = 0x010f
for x in extra_pnginfo:
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
inital_exif -= 1
if num_frames == 0:
num_frames = len(pil_images)
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,) } }
class SaveAnimatedPNG:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": return IO.Schema(
{"images": ("IMAGE", ), node_id="SaveAnimatedWEBP",
"filename_prefix": ("STRING", {"default": "ComfyUI"}), category="image/animation",
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), inputs=[
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) IO.Image.Input("images"),
}, IO.String.Input("filename_prefix", default="ComfyUI"),
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
} IO.Boolean.Input("lossless", default=True),
IO.Int.Input("quality", default=80, min=0, max=100),
IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
RETURN_TYPES = () @classmethod
FUNCTION = "save_images" def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.ImageSaveHelper.get_save_animated_webp_ui(
images=images,
filename_prefix=filename_prefix,
cls=cls,
fps=fps,
lossless=lossless,
quality=quality,
method=cls.COMPRESS_METHODS.get(method)
)
)
OUTPUT_NODE = True save_images = execute # TODO: remove
CATEGORY = "image/animation"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
file = f"{filename}_{counter:05}_.png"
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
return { "ui": { "images": results, "animated": (True,)} }
class SVG:
"""
Stores SVG representations via a list of BytesIO objects.
"""
def __init__(self, data: list[BytesIO]):
self.data = data
def combine(self, other: 'SVG') -> 'SVG':
return SVG(self.data + other.data)
@staticmethod
def combine_all(svgs: list['SVG']) -> 'SVG':
all_svgs_list: list[BytesIO] = []
for svg_item in svgs:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
class ImageStitch: class SaveAnimatedPNG(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveAnimatedPNG",
category="image/animation",
inputs=[
IO.Image.Input("images"),
IO.String.Input("filename_prefix", default="ComfyUI"),
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
IO.Int.Input("compress_level", default=4, min=0, max=9),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.ImageSaveHelper.get_save_animated_png_ui(
images=images,
filename_prefix=filename_prefix,
cls=cls,
fps=fps,
compress_level=compress_level,
)
)
save_images = execute # TODO: remove
class ImageStitch(IO.ComfyNode):
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="ImageStitch",
"image1": ("IMAGE",), display_name="Image Stitch",
"direction": (["right", "down", "left", "up"], {"default": "right"}), description="Stitches image2 to image1 in the specified direction.\n"
"match_image_size": ("BOOLEAN", {"default": True}), "If image2 is not provided, returns image1 unchanged.\n"
"spacing_width": ( "Optional spacing can be added between images.",
"INT", category="image/transform",
{"default": 0, "min": 0, "max": 1024, "step": 2}, inputs=[
), IO.Image.Input("image1"),
"spacing_color": ( IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"),
["white", "black", "red", "green", "blue"], IO.Boolean.Input("match_image_size", default=True),
{"default": "white"}, IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2),
), IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"),
}, IO.Image.Input("image2", optional=True),
"optional": { ],
"image2": ("IMAGE",), outputs=[IO.Image.Output()],
}, )
}
RETURN_TYPES = ("IMAGE",) @classmethod
FUNCTION = "stitch" def execute(
CATEGORY = "image/transform" cls,
DESCRIPTION = """
Stitches image2 to image1 in the specified direction.
If image2 is not provided, returns image1 unchanged.
Optional spacing can be added between images.
"""
def stitch(
self,
image1, image1,
direction, direction,
match_image_size, match_image_size,
spacing_width, spacing_width,
spacing_color, spacing_color,
image2=None, image2=None,
): ) -> IO.NodeOutput:
if image2 is None: if image2 is None:
return (image1,) return IO.NodeOutput(image1)
# Handle batch size differences # Handle batch size differences
if image1.shape[0] != image2.shape[0]: if image1.shape[0] != image2.shape[0]:
@ -412,36 +363,30 @@ Optional spacing can be added between images.
images.insert(1, spacing) images.insert(1, spacing)
concat_dim = 2 if direction in ["left", "right"] else 1 concat_dim = 2 if direction in ["left", "right"] else 1
return (torch.cat(images, dim=concat_dim),) return IO.NodeOutput(torch.cat(images, dim=concat_dim))
stitch = execute # TODO: remove
class ResizeAndPadImage(IO.ComfyNode):
class ResizeAndPadImage:
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="ResizeAndPadImage",
"image": ("IMAGE",), category="image/transform",
"target_width": ("INT", { inputs=[
"default": 512, IO.Image.Input("image"),
"min": 1, IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
"max": MAX_RESOLUTION, IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
"step": 1 IO.Combo.Input("padding_color", options=["white", "black"]),
}), IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]),
"target_height": ("INT", { ],
"default": 512, outputs=[IO.Image.Output()],
"min": 1, )
"max": MAX_RESOLUTION,
"step": 1
}),
"padding_color": (["white", "black"],),
"interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
}
}
RETURN_TYPES = ("IMAGE",) @classmethod
FUNCTION = "resize_and_pad" def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput:
CATEGORY = "image/transform"
def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
batch_size, orig_height, orig_width, channels = image.shape batch_size, orig_height, orig_width, channels = image.shape
scale_w = target_width / orig_width scale_w = target_width / orig_width
@ -469,52 +414,47 @@ class ResizeAndPadImage:
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
output = padded.permute(0, 2, 3, 1) output = padded.permute(0, 2, 3, 1)
return (output,) return IO.NodeOutput(output)
class SaveSVGNode: resize_and_pad = execute # TODO: remove
"""
Save SVG files on disk.
"""
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
RETURN_TYPES = () class SaveSVGNode(IO.ComfyNode):
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "save_svg"
CATEGORY = "image/save" # Changed
OUTPUT_NODE = True
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="SaveSVGNode",
"svg": ("SVG",), # Changed description="Save SVG files on disk.",
"filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) category="image/save",
}, inputs=[
"hidden": { IO.SVG.Input("svg"),
"prompt": "PROMPT", IO.String.Input(
"extra_pnginfo": "EXTRA_PNGINFO" "filename_prefix",
} default="svg/ComfyUI",
} tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None): @classmethod
filename_prefix += self.prefix_append def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = list() results: list[UI.SavedResult] = []
# Prepare metadata JSON # Prepare metadata JSON
metadata_dict = {} metadata_dict = {}
if prompt is not None: if cls.hidden.prompt is not None:
metadata_dict["prompt"] = prompt metadata_dict["prompt"] = cls.hidden.prompt
if extra_pnginfo is not None: if cls.hidden.extra_pnginfo is not None:
metadata_dict.update(extra_pnginfo) metadata_dict.update(cls.hidden.extra_pnginfo)
# Convert metadata to JSON string # Convert metadata to JSON string
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
for batch_number, svg_bytes in enumerate(svg.data): for batch_number, svg_bytes in enumerate(svg.data):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.svg" file = f"{filename_with_batch_num}_{counter:05}_.svg"
@ -544,57 +484,64 @@ class SaveSVGNode:
with open(os.path.join(full_output_folder, file), 'wb') as svg_file: with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
svg_file.write(svg_content.encode('utf-8')) svg_file.write(svg_content.encode('utf-8'))
results.append({ results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output))
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1 counter += 1
return { "ui": { "images": results } } return IO.NodeOutput(ui={"images": results})
class GetImageSize: save_svg = execute # TODO: remove
class GetImageSize(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="GetImageSize",
"image": (IO.IMAGE,), display_name="Get Image Size",
}, description="Returns width and height of the image, and passes it through unchanged.",
"hidden": { category="image",
"unique_id": "UNIQUE_ID", inputs=[
} IO.Image.Input("image"),
} ],
outputs=[
IO.Int.Output(display_name="width"),
IO.Int.Output(display_name="height"),
IO.Int.Output(display_name="batch_size"),
],
hidden=[IO.Hidden.unique_id],
)
RETURN_TYPES = (IO.INT, IO.INT, IO.INT) @classmethod
RETURN_NAMES = ("width", "height", "batch_size") def execute(cls, image) -> IO.NodeOutput:
FUNCTION = "get_size"
CATEGORY = "image"
DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
def get_size(self, image, unique_id=None) -> tuple[int, int]:
height = image.shape[1] height = image.shape[1]
width = image.shape[2] width = image.shape[2]
batch_size = image.shape[0] batch_size = image.shape[0]
# Send progress text to display size on the node # Send progress text to display size on the node
if unique_id: if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id) PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
return width, height, batch_size return IO.NodeOutput(width, height, batch_size)
get_size = execute # TODO: remove
class ImageRotate(IO.ComfyNode):
class ImageRotate:
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": (IO.IMAGE,), return IO.Schema(
"rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],), node_id="ImageRotate",
}} category="image/transform",
RETURN_TYPES = (IO.IMAGE,) inputs=[
FUNCTION = "rotate" IO.Image.Input("image"),
IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/transform" @classmethod
def execute(cls, image, rotation) -> IO.NodeOutput:
def rotate(self, image, rotation):
rotate_by = 0 rotate_by = 0
if rotation.startswith("90"): if rotation.startswith("90"):
rotate_by = 1 rotate_by = 1
@ -604,41 +551,57 @@ class ImageRotate:
rotate_by = 3 rotate_by = 3
image = torch.rot90(image, k=rotate_by, dims=[2, 1]) image = torch.rot90(image, k=rotate_by, dims=[2, 1])
return (image,) return IO.NodeOutput(image)
rotate = execute # TODO: remove
class ImageFlip(IO.ComfyNode):
class ImageFlip:
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "image": (IO.IMAGE,), return IO.Schema(
"flip_method": (["x-axis: vertically", "y-axis: horizontally"],), node_id="ImageFlip",
}} category="image/transform",
RETURN_TYPES = (IO.IMAGE,) inputs=[
FUNCTION = "flip" IO.Image.Input("image"),
IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/transform" @classmethod
def execute(cls, image, flip_method) -> IO.NodeOutput:
def flip(self, image, flip_method):
if flip_method.startswith("x"): if flip_method.startswith("x"):
image = torch.flip(image, dims=[1]) image = torch.flip(image, dims=[1])
elif flip_method.startswith("y"): elif flip_method.startswith("y"):
image = torch.flip(image, dims=[2]) image = torch.flip(image, dims=[2])
return (image,) return IO.NodeOutput(image)
class ImageScaleToMaxDimension: flip = execute # TODO: remove
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
class ImageScaleToMaxDimension(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"image": ("IMAGE",), return IO.Schema(
"upscale_method": (s.upscale_methods,), node_id="ImageScaleToMaxDimension",
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}} category="image/upscaling",
RETURN_TYPES = ("IMAGE",) inputs=[
FUNCTION = "upscale" IO.Image.Input("image"),
IO.Combo.Input(
"upscale_method",
options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"],
),
IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "image/upscaling" @classmethod
def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
def upscale(self, image, upscale_method, largest_size):
height = image.shape[1] height = image.shape[1]
width = image.shape[2] width = image.shape[2]
@ -655,20 +618,30 @@ class ImageScaleToMaxDimension:
samples = image.movedim(-1, 1) samples = image.movedim(-1, 1)
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1, -1) s = s.movedim(1, -1)
return (s,) return IO.NodeOutput(s)
NODE_CLASS_MAPPINGS = { upscale = execute # TODO: remove
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"ImageFromBatch": ImageFromBatch, class ImagesExtension(ComfyExtension):
"ImageAddNoise": ImageAddNoise, @override
"SaveAnimatedWEBP": SaveAnimatedWEBP, async def get_node_list(self) -> list[type[IO.ComfyNode]]:
"SaveAnimatedPNG": SaveAnimatedPNG, return [
"SaveSVGNode": SaveSVGNode, ImageCrop,
"ImageStitch": ImageStitch, RepeatImageBatch,
"ResizeAndPadImage": ResizeAndPadImage, ImageFromBatch,
"GetImageSize": GetImageSize, ImageAddNoise,
"ImageRotate": ImageRotate, SaveAnimatedWEBP,
"ImageFlip": ImageFlip, SaveAnimatedPNG,
"ImageScaleToMaxDimension": ImageScaleToMaxDimension, SaveSVGNode,
} ImageStitch,
ResizeAndPadImage,
GetImageSize,
ImageRotate,
ImageFlip,
ImageScaleToMaxDimension,
]
async def comfy_entrypoint() -> ImagesExtension:
return ImagesExtension()
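The hunks above migrate these image nodes from the legacy INPUT_TYPES/RETURN_TYPES attributes to the schema-based API (define_schema plus a classmethod execute returning IO.NodeOutput), keeping the old method names as aliases marked # TODO: remove. As a rough sketch of the resulting pattern, not part of this commit and using only the comfy_api.latest surface visible in these hunks (the node id and the invert operation are invented for illustration):

from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


class ImageInvertExample(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ImageInvertExample",   # hypothetical node id, for illustration only
            category="image/transform",
            inputs=[io.Image.Input("image")],
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, image) -> io.NodeOutput:
        # IMAGE tensors are [batch, height, width, channels] in the 0..1 range.
        return io.NodeOutput(1.0 - image)


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [ImageInvertExample]


async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()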
View File
@ -5,6 +5,7 @@ import nodes
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
import logging import logging
import math
def reshape_latent_to(target_shape, latent, repeat_batch=True): def reshape_latent_to(target_shape, latent, repeat_batch=True):
if latent.shape[1:] != target_shape[1:]: if latent.shape[1:] != target_shape[1:]:
@ -207,6 +208,47 @@ class LatentCut(io.ComfyNode):
samples_out["samples"] = torch.narrow(s1, dim, index, amount) samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return io.NodeOutput(samples_out) return io.NodeOutput(samples_out)
class LatentCutToBatch(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LatentCutToBatch",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["t", "x", "y"]),
io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
samples_out = samples.copy()
s1 = samples["samples"]
if "x" in dim:
dim = s1.ndim - 1
elif "y" in dim:
dim = s1.ndim - 2
elif "t" in dim:
dim = s1.ndim - 3
if dim < 2:
return io.NodeOutput(samples)
s = s1.movedim(dim, 1)
if s.shape[1] < slice_size:
slice_size = s.shape[1]
elif s.shape[1] % slice_size != 0:
s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
new_shape = [-1, slice_size] + list(s.shape[2:])
samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
return io.NodeOutput(samples_out)
class LatentBatch(io.ComfyNode): class LatentBatch(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -435,6 +477,7 @@ class LatentExtension(ComfyExtension):
LatentInterpolate, LatentInterpolate,
LatentConcat, LatentConcat,
LatentCut, LatentCut,
LatentCutToBatch,
LatentBatch, LatentBatch,
LatentBatchSeedBehavior, LatentBatchSeedBehavior,
LatentApplyOperation, LatentApplyOperation,
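The new LatentCutToBatch node folds slices along the chosen axis into the batch dimension, trimming any remainder that does not fill a whole slice. A standalone sketch of the same trim-and-fold with concrete shapes (plain torch, outside the node wrapper; the [b, c, t, y, x] layout is assumed from the t/x/y axis mapping above):

import math
import torch

s1 = torch.zeros([1, 16, 10, 32, 32])   # 10 temporal frames
slice_size = 4
dim = s1.ndim - 3                        # "t" axis -> dim 2

s = s1.movedim(dim, 1)                   # [1, 10, 16, 32, 32]
if s.shape[1] < slice_size:
    slice_size = s.shape[1]
elif s.shape[1] % slice_size != 0:
    # drop trailing frames that do not fill a whole slice (10 -> 8)
    s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]

out = s.reshape([-1, slice_size] + list(s.shape[2:])).movedim(1, dim)
print(out.shape)                         # torch.Size([2, 16, 4, 32, 32])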
View File
@ -348,7 +348,7 @@ class ZImageControlPatch:
if self.mask is None: if self.mask is None:
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1] mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
else: else:
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center") mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
if latent_image is None: if latent_image is None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5)) latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
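The single-line change to ZImageControlPatch only moves the flattened, averaged mask onto the latent's device before comfy.utils.common_upscale runs, avoiding a CPU/GPU mix when the mask comes from a CPU-side node. A minimal sketch of the pattern in isolation (tensor shapes and devices here are illustrative):

import torch
import comfy.utils

device = "cuda" if torch.cuda.is_available() else "cpu"
inpaint_image_latent = torch.zeros([1, 16, 64, 64], device=device)
mask = torch.ones([1, 512, 512])  # masks frequently arrive on the CPU

mask_ = comfy.utils.common_upscale(
    mask.view(mask.shape[0], -1, mask.shape[-2], mask.shape[-1])
        .mean(dim=1, keepdim=True)
        .to(device=inpaint_image_latent.device),  # the fix: match devices first
    inpaint_image_latent.shape[-1],
    inpaint_image_latent.shape[-2],
    "nearest",
    "center",
)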
View File
@ -3,7 +3,9 @@ import comfy.utils
import math import math
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
import comfy.model_management
import torch
import nodes
class TextEncodeQwenImageEdit(io.ComfyNode): class TextEncodeQwenImageEdit(io.ComfyNode):
@classmethod @classmethod
@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
return io.NodeOutput(conditioning) return io.NodeOutput(conditioning)
class EmptyQwenImageLayeredLatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyQwenImageLayeredLatentImage",
display_name="Empty Qwen Image Layered Latent",
category="latent/qwen",
inputs=[
io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
class QwenExtension(ComfyExtension): class QwenExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ return [
TextEncodeQwenImageEdit, TextEncodeQwenImageEdit,
TextEncodeQwenImageEditPlus, TextEncodeQwenImageEditPlus,
EmptyQwenImageLayeredLatentImage,
] ]
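The added EmptyQwenImageLayeredLatentImage node allocates a 16-channel latent with one extra layer beyond the requested count, at 1/8 spatial resolution, on the intermediate device. A quick standalone shape check using the node defaults (plain torch, device argument omitted):

import torch

width, height, layers, batch_size = 640, 640, 3, 1  # node defaults
latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8])
print(latent.shape)  # torch.Size([1, 16, 4, 80, 80])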
View File
@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.5.1" __version__ = "0.6.0"
View File
@ -1 +1 @@
comfyui_manager==4.0.3b5 comfyui_manager==4.0.4
View File
@ -343,7 +343,7 @@ class VAEEncode:
CATEGORY = "latent" CATEGORY = "latent"
def encode(self, vae, pixels): def encode(self, vae, pixels):
t = vae.encode(pixels[:,:,:,:3]) t = vae.encode(pixels)
return ({"samples":t}, ) return ({"samples":t}, )
class VAEEncodeTiled: class VAEEncodeTiled:
@ -361,7 +361,7 @@ class VAEEncodeTiled:
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
return ({"samples": t}, ) return ({"samples": t}, )
class VAEEncodeForInpaint: class VAEEncodeForInpaint:
@ -970,7 +970,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ),
}, },
"optional": { "optional": {
"device": (["default", "cpu"], {"advanced": True}), "device": (["default", "cpu"], {"advanced": True}),
@ -980,7 +980,7 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small" DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
def load_clip(self, clip_name1, clip_name2, type, device="default"): def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
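In this file, VAEEncode and VAEEncodeTiled stop slicing the input to three channels before encoding (pixels[:,:,:,:3] becomes pixels), presumably deferring channel handling to the VAE itself, and DualCLIPLoader merely gains a "newbie" combo entry: load_clip already resolves the string through getattr with a STABLE_DIFFUSION fallback, so the option works as long as comfy.sd.CLIPType has a matching member. A small sketch of that lookup (the enum here is an illustrative stand-in, not the real comfy.sd.CLIPType):

import enum

class CLIPType(enum.Enum):
    # Stand-in enum; only the fallback member is taken from the diff,
    # the others are assumed for the sketch.
    STABLE_DIFFUSION = 1
    FLUX = 2
    NEWBIE = 3

def resolve(type_str: str) -> CLIPType:
    # Same pattern as load_clip: unknown strings fall back to STABLE_DIFFUSION.
    return getattr(CLIPType, type_str.upper(), CLIPType.STABLE_DIFFUSION)

print(resolve("newbie"))      # CLIPType.NEWBIE
print(resolve("not_a_type"))  # CLIPType.STABLE_DIFFUSION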
View File
@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.5.1" version = "0.6.0"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"
View File
@ -1,5 +1,5 @@
comfyui-frontend-package==1.34.9 comfyui-frontend-package==1.35.9
comfyui-workflow-templates==0.7.59 comfyui-workflow-templates==0.7.64
comfyui-embedded-docs==0.3.1 comfyui-embedded-docs==0.3.1
torch torch
torchsde torchsde
View File
@ -25,7 +25,7 @@ class TestImageStitch:
result = node.stitch(image1, "right", True, 0, "white", image2=None) result = node.stitch(image1, "right", True, 0, "white", image2=None)
assert len(result) == 1 assert len(result.result) == 1
assert torch.equal(result[0], image1) assert torch.equal(result[0], image1)
def test_basic_horizontal_stitch_right(self): def test_basic_horizontal_stitch_right(self):
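The updated assertion measures the tuple carried inside the returned value (result.result) instead of the return value itself, while the element is still compared via result[0]. A hedged stand-in showing the behaviour the test now relies on (this class is illustrative, not the real IO.NodeOutput):

class FakeNodeOutput:
    # Illustrative stand-in for the node-output wrapper used by the V3 test;
    # the real class lives in comfy_api and may differ in detail.
    def __init__(self, *values):
        self.result = values          # tuple inspected via len(result.result)

    def __getitem__(self, idx):
        return self.result[idx]       # result[0] keeps working unchanged

out = FakeNodeOutput("image-tensor")
assert len(out.result) == 1
assert out[0] == "image-tensor"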