diff --git a/README.md b/README.md index bae955b1b..b0f62695b 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - Releases a new stable version (e.g., v0.7.0) roughly every week. + - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release. + - Minor versions will be used for releases off the master branch. + - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense. - Commits outside of the stable release tags may be very unstable and break many custom nodes. - Serves as the foundation for the desktop release diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2979b3ca1..1e0f86026 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -143,7 +143,7 @@ class IndexListContextHandler(ContextHandlerABC): # if multiple conds, split based on primary region if self.split_conds_to_windows and len(cond_in) > 1: region = window.get_region_index(len(cond_in)) - logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}") + logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}") cond_in = [cond_in[region]] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 5628e2ba3..e80b1c138 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -625,7 +625,7 @@ class NextDiT(nn.Module): if pooled is not None: pooled = self.clip_text_pooled_proj(pooled) else: - pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 902af30ed..00c597535 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis): class QwenTimestepProjEmbeddings(nn.Module): - def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): + def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) self.timestep_embedder = TimestepEmbedding( @@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module): operations=operations ) - def forward(self, timestep, hidden_states): + self.use_additional_t_cond = use_additional_t_cond + if self.use_additional_t_cond: + self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype) + + def forward(self, timestep, hidden_states, addition_t_cond=None): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + + if self.use_additional_t_cond: + if addition_t_cond is None: + addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long) + timesteps_emb += 
self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype) + return timesteps_emb @@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module): num_attention_heads: int = 24, joint_attention_dim: int = 3584, pooled_projection_dim: int = 768, - guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), default_ref_method="index", image_model=None, final_layer=True, + use_additional_t_cond=False, dtype=None, device=None, operations=None, @@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module): self.time_text_embed = QwenTimestepProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim, + use_additional_t_cond=use_additional_t_cond, dtype=dtype, device=device, operations=operations @@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module): patch_size = self.patch_size hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) orig_shape = hidden_states.shape - hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) - hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) - hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6) + hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + t_len = t h_len = ((h + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size) h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device) - img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) - return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device) - def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): + if t_len > 1: + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1) + else: + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index + + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2) + return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape + + def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, 
transformer_options) - ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) + ).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs) def _forward( self, @@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module): timesteps, context, attention_mask=None, - guidance: torch.Tensor = None, ref_latents=None, + additional_t_cond=None, transformer_options={}, control=None, **kwargs @@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module): index = 0 ref_method = kwargs.get("ref_latents_method", self.default_ref_method) index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") + negative_ref_method = ref_method == "negative_index" timestep_zero = ref_method == "index_timestep_zero" for ref in ref_latents: if index_ref_method: index += 1 h_offset = 0 w_offset = 0 + elif negative_ref_method: + index -= 1 + h_offset = 0 + w_offset = 0 else: index = 1 h_offset = 0 @@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - if guidance is not None: - guidance = guidance * 1000 - - temb = ( - self.time_text_embed(timestep, hidden_states) - if guidance is None - else self.time_text_embed(timestep, guidance, hidden_states) - ) + temb = self.time_text_embed(timestep, hidden_states, additional_t_cond) patches_replace = transformer_options.get("patches_replace", {}) patches = transformer_options.get("patches", {}) @@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) - hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) + hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6) return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] diff --git a/comfy/model_base.py b/comfy/model_base.py index 6b8a8454d..c4f3c0639 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1110,7 +1110,7 @@ class Lumina2(BaseModel): if 'num_tokens' not in out: out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) - clip_text_pooled = kwargs["pooled_output"] # Newbie + clip_text_pooled = kwargs.get("pooled_output", None) # NewBie if clip_text_pooled is not None: out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7148c77fd..539e296ed 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) - if ctd_weight is not None: + if ctd_weight is not None: # NewBie dit_config["clip_text_dim"] = ctd_weight.shape[0] + # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 dit_config["n_kv_heads"] = 30 @@ -620,6 +621,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = 
count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511 dit_config["default_ref_method"] = "index_timestep_zero" + if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered + dit_config["use_additional_t_cond"] = True + dit_config["default_ref_method"] = "negative_index" return dit_config if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 diff --git a/comfy/model_management.py b/comfy/model_management.py index 40717b1e4..1889ab0ac 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,7 @@ import importlib import platform import weakref import gc +import os class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -333,13 +334,15 @@ except: SUPPORT_FP8_OPS = args.supports_fp8_compute AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] +AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): - torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD - logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': + torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD + logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) diff --git a/comfy/samplers.py b/comfy/samplers.py index 8340d376c..1989ef107 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -984,9 +984,6 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - if denoise_mask is not None: - denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) - noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) @@ -1013,6 +1010,24 @@ class CFGGuider: else: latent_shapes = [latent_image.shape] + if denoise_mask is not None: + if denoise_mask.is_nested: + denoise_masks = denoise_mask.unbind() + denoise_masks = denoise_masks[:len(latent_shapes)] + else: + denoise_masks = [denoise_mask] + + for i in range(len(denoise_masks), len(latent_shapes)): + denoise_masks.append(torch.ones(latent_shapes[i])) + + for i in range(len(denoise_masks)): + denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device) + + if len(denoise_masks) > 1: + denoise_mask, _ = comfy.utils.pack_latents(denoise_masks) + else: + denoise_mask = denoise_masks[0] + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) diff --git a/comfy/sd.py b/comfy/sd.py index c2a9728f3..7de7dd9c6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.newbie import comfy.model_patcher import comfy.lora @@ -1008,6 +1010,7 @@ class CLIPType(Enum): OVIS 
= 21 KANDINSKY5 = 22 KANDINSKY5_IMAGE = 23 + NEWBIE = 24 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1038,6 +1041,7 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 QWEN3_2B = 17 + JINA_CLIP_2 = 18 def detect_te_model(sd): @@ -1047,6 +1051,8 @@ def detect_te_model(sd): return TEModel.CLIP_H if "text_model.encoder.layers.0.mlp.fc1.weight" in sd: return TEModel.CLIP_L + if "model.encoder.layers.0.mixer.Wqkv.weight" in sd: + return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] if weight.shape[-1] == 4096: @@ -1207,6 +1213,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN3_2B: clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer + elif te_model == TEModel.JINA_CLIP_2: + clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper + clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper else: # clip_l if clip_type == CLIPType.SD3: @@ -1262,6 +1271,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.NEWBIE: + clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer + if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]: + clip_data_gemma = clip_data[0] + clip_data_jina = clip_data[1] + else: + clip_data_gemma = clip_data[1] + clip_data_jina = clip_data[0] + tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) + tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 962948dae..c512ca5d0 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) @@ -513,6 +513,8 @@ class SDTokenizer: self.embedding_size = embedding_size self.embedding_key = embedding_key + self.disable_weights = 
disable_weights + def _try_get_embedding(self, embedding_name:str): ''' Takes a potential embedding name and tries to retrieve it. @@ -547,7 +549,7 @@ class SDTokenizer: min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) text = escape_important(text) - if kwargs.get("disable_weights", False): + if kwargs.get("disable_weights", self.disable_weights): parsed_weights = [(text, 1.0)] else: parsed_weights = token_weights(text, 1.0) diff --git a/comfy/text_encoders/jina_clip_2.py b/comfy/text_encoders/jina_clip_2.py new file mode 100644 index 000000000..0cffb6d16 --- /dev/null +++ b/comfy/text_encoders/jina_clip_2.py @@ -0,0 +1,219 @@ +# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation: +# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py +# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py + +from dataclasses import dataclass + +import torch +from torch import nn as nn +from torch.nn import functional as F + +import comfy.model_management +import comfy.ops +from comfy import sd1_clip +from .spiece_tokenizer import SPieceTokenizer + +class JinaClip2Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192 + super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2") + +# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json +@dataclass +class XLMRobertaConfig: + vocab_size: int = 250002 + type_vocab_size: int = 1 + hidden_size: int = 1024 + num_hidden_layers: int = 24 + num_attention_heads: int = 16 + rotary_emb_base: float = 20000.0 + intermediate_size: int = 4096 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + layer_norm_eps: float = 1e-05 + bos_token_id: int = 0 + eos_token_id: int = 2 + pad_token_id: int = 1 + +class XLMRobertaEmbeddings(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype) + self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype) + + def forward(self, input_ids=None, embeddings=None): + if input_ids is not None and embeddings is None: + embeddings = self.word_embeddings(input_ids) + + if embeddings is not None: + token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, 
dtype=torch.int32) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + return embeddings + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, base, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype: + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + emb = torch.cat((freqs, freqs), dim=-1) + self._cos_cached = emb.cos().to(dtype) + self._sin_cached = emb.sin().to(dtype) + + def forward(self, q, k): + batch, seqlen, heads, head_dim = q.shape + self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype) + + cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim) + sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim) + + def rotate_half(x): + size = x.shape[-1] // 2 + x1, x2 = x[..., :size], x[..., size:] + return torch.cat((-x2, x1), dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class MHA(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = embed_dim // config.num_attention_heads + + self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device) + self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype) + self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, x, mask=None, optimized_attention=None): + qkv = self.Wqkv(x) + batch_size, seq_len, _ = qkv.shape + qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + + q, k = self.rotary_emb(q, k) + + # NHD -> HND + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True) + return self.out_proj(out) + +class MLP(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype) + self.activation = F.gelu + self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + +class Block(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.mixer = MHA(config, device=device, dtype=dtype, ops=ops) + self.dropout1 = nn.Dropout(config.hidden_dropout_prob) + self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) + self.dropout2 = nn.Dropout(config.hidden_dropout_prob) + self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + + def forward(self, hidden_states, mask=None, optimized_attention=None): + mixer_out = 
self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention) + hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states) + mlp_out = self.mlp(hidden_states) + hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states) + return hidden_states + +class XLMRobertaEncoder(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask=None): + optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True) + for layer in self.layers: + hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention) + return hidden_states + +class XLMRobertaModel_(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops) + self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.emb_drop = nn.Dropout(config.hidden_dropout_prob) + self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops) + + def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + x = self.embeddings(input_ids=input_ids, embeddings=embeds) + x = self.emb_ln(x) + x = self.emb_drop(x) + + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1])) + mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) + + sequence_output = self.encoder(x, attention_mask=mask) + + # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py + pooled_output = None + if attention_mask is None: + pooled_output = sequence_output.mean(dim=1) + else: + attention_mask = attention_mask.to(sequence_output.dtype) + pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True) + + # Intermediate output is not yet implemented, use None for placeholder + return sequence_output, None, pooled_output + +class XLMRobertaModel(nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.config = XLMRobertaConfig(**config_dict) + self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations) + self.num_layers = self.config.num_hidden_layers + + def get_input_embeddings(self): + return self.model.embeddings.word_embeddings + + def set_input_embeddings(self, embeddings): + self.model.embeddings.word_embeddings = embeddings + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + +class JinaClip2TextModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) + +class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, 
clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 0d07ac8c6..ed29e014d 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -3,7 +3,6 @@ import torch.nn as nn from dataclasses import dataclass from typing import Optional, Any import math -import logging from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management @@ -177,7 +176,7 @@ class Gemma3_4B_Config: num_key_value_heads: int = 4 max_position_embeddings: int = 131072 rms_norm_eps: float = 1e-6 - rope_theta = [10000.0, 1000000.0] + rope_theta = [1000000.0, 10000.0] transformer_type: str = "gemma3" head_dim = 256 rms_norm_add = True @@ -186,8 +185,8 @@ class Gemma3_4B_Config: rope_dims = None q_norm = "gemma3" k_norm = "gemma3" - sliding_attention = [False, False, False, False, False, 1024] - rope_scale = [1.0, 8.0] + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] final_norm: bool = True class RMSNorm(nn.Module): @@ -370,7 +369,7 @@ class TransformerBlockGemma2(nn.Module): self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) - if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens) + if config.sliding_attention is not None: self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] else: self.sliding_attention = False @@ -387,7 +386,12 @@ class TransformerBlockGemma2(nn.Module): if self.transformer_type == 'gemma3': if self.sliding_attention: if x.shape[1] > self.sliding_attention: - logging.warning("Warning: sliding attention not implemented, results may be incorrect") + sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype) + sliding_mask.tril_(diagonal=-self.sliding_attention) + if attention_mask is not None: + attention_mask = attention_mask + sliding_mask + else: + attention_mask = sliding_mask freqs_cis = freqs_cis[1] else: freqs_cis = freqs_cis[0] diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index 7a6cfdab2..b29a7cc87 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -14,7 +14,7 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer): class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} @@ -33,6 +33,11 @@ class Gemma2_2BModel(sd1_clip.SDClipModel): class Gemma3_4BModel(sd1_clip.SDClipModel): def 
__init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class LuminaModel(sd1_clip.SD1ClipModel): diff --git a/comfy/text_encoders/newbie.py b/comfy/text_encoders/newbie.py new file mode 100644 index 000000000..db2324576 --- /dev/null +++ b/comfy/text_encoders/newbie.py @@ -0,0 +1,62 @@ +import torch + +import comfy.model_management +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.lumina2 + +class NewBieTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]}) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]}) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs) + out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + raise NotImplementedError + + def state_dict(self): + return {} + +class NewBieTEModel(torch.nn.Module): + def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}): + super().__init__() + dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device) + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options) + self.dtypes = {dtype, dtype_gemma} + + def set_clip_options(self, options): + self.gemma.set_clip_options(options) + self.jina.set_clip_options(options) + + def reset_clip_options(self): + self.gemma.reset_clip_options() + self.jina.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_gemma = token_weight_pairs["gemma"] + token_weight_pairs_jina = token_weight_pairs["jina"] + + gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma) + jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina) + + return gemma_out, jina_pooled, gemma_extra + + def load_sd(self, sd): + if "model.layers.0.self_attn.q_norm.weight" in sd: + return self.gemma.load_sd(sd) + else: + return self.jina.load_sd(sd) + +def te(dtype_llama=None, llama_quantization_metadata=None): + class NewBieTEModel_(NewBieTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + super().__init__(dtype_gemma=dtype_llama, device=device, 
dtype=dtype, model_options=model_options) + return NewBieTEModel_ diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py index 77cd76f9b..b8c2f618b 100644 --- a/comfy_api_nodes/apis/bytedance_api.py +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -10,7 +10,7 @@ class Text2ImageTaskCreationRequest(BaseModel): size: str | None = Field(None) seed: int | None = Field(0, ge=0, le=2147483647) guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(True) + watermark: bool | None = Field(False) class Image2ImageTaskCreationRequest(BaseModel): @@ -21,7 +21,7 @@ class Image2ImageTaskCreationRequest(BaseModel): size: str | None = Field("adaptive") seed: int | None = Field(..., ge=0, le=2147483647) guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(True) + watermark: bool | None = Field(False) class Seedream4Options(BaseModel): @@ -37,7 +37,7 @@ class Seedream4TaskCreationRequest(BaseModel): seed: int = Field(..., ge=0, le=2147483647) sequential_image_generation: str = Field("disabled") sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) - watermark: bool = Field(True) + watermark: bool = Field(False) class ImageTaskCreationResponse(BaseModel): diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index f8edc38c9..d81337dae 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel): systemInstruction: GeminiSystemInstructionContent | None = Field(None) tools: list[GeminiTool] | None = Field(None) videoMetadata: GeminiVideoMetadata | None = Field(None) + uploadImagesToStorage: bool = Field(True) class GeminiGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 8826dea0c..ce077d6b3 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,10 +1,8 @@ -from inspect import cleandoc - import torch from pydantic import BaseModel from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bfl_api import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, @@ -28,7 +26,7 @@ from comfy_api_nodes.util import ( ) -def convert_mask_to_image(mask: torch.Tensor): +def convert_mask_to_image(mask: Input.Image): """ Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. """ @@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor): class FluxProUltraImageNode(IO.ComfyNode): - """ - Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. 
- """ @classmethod def define_schema(cls) -> IO.Schema: @@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode): node_id="FluxProUltraImageNode", display_name="Flux 1.1 [pro] Ultra Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.", inputs=[ IO.String.Input( "prompt", @@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode): prompt_upsampling: bool = False, raw: bool = False, seed: int = 0, - image_prompt: torch.Tensor | None = None, + image_prompt: Input.Image | None = None, image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: @@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode): class FluxKontextProImageNode(IO.ComfyNode): - """ - Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode): node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.", inputs=[ IO.String.Input( "prompt", @@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode): aspect_ratio: str, guidance: float, steps: int, - input_image: torch.Tensor | None = None, + input_image: Input.Image | None = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode): class FluxKontextMaxImageNode(FluxKontextProImageNode): - """ - Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio. - """ - DESCRIPTION = cleandoc(__doc__ or "") + DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio." BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" NODE_ID = "FluxKontextMaxImageNode" DISPLAY_NAME = "Flux.1 Kontext [max] Image" class FluxProExpandNode(IO.ComfyNode): - """ - Outpaints image based on prompt. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode): node_id="FluxProExpandNode", display_name="Flux.1 Expand Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Outpaints image based on prompt.", inputs=[ IO.Image.Input("image"), IO.String.Input( @@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, prompt: str, prompt_upsampling: bool, top: int, @@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode): class FluxProFillNode(IO.ComfyNode): - """ - Inpaints image based on mask and prompt. 
- """ @classmethod def define_schema(cls) -> IO.Schema: @@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode): node_id="FluxProFillNode", display_name="Flux.1 Fill Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Inpaints image based on mask and prompt.", inputs=[ IO.Image.Input("image"), IO.Mask.Input("mask"), @@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, - mask: torch.Tensor, + image: Input.Image, + mask: Input.Image, prompt: str, prompt_upsampling: bool, steps: int, @@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode): class Flux2ProImageNode(IO.ComfyNode): + NODE_ID = "Flux2ProImageNode" + DISPLAY_NAME = "Flux.2 [pro] Image" + API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate" + @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( - node_id="Flux2ProImageNode", - display_name="Flux.2 [pro] Image", + node_id=cls.NODE_ID, + display_name=cls.DISPLAY_NAME, category="api node/image/BFL", description="Generates images synchronously based on prompt and resolution.", inputs=[ @@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode): ), IO.Boolean.Input( "prompt_upsampling", - default=False, + default=True, tooltip="Whether to perform upsampling on the prompt. " - "If active, automatically modifies the prompt for more creative generation, " - "but results are nondeterministic (same seed will not produce exactly the same result).", + "If active, automatically modifies the prompt for more creative generation.", ), - IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), + IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."), ], outputs=[IO.Image.Output()], hidden=[ @@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode): height: int, seed: int, prompt_upsampling: bool, - images: torch.Tensor | None = None, + images: Input.Image | None = None, ) -> IO.NodeOutput: reference_images = {} if images is not None: @@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode): reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) initial_response = await sync_op( cls, - ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), + ApiEndpoint(path=cls.API_ENDPOINT, method="POST"), response_model=BFLFluxProGenerateResponse, data=Flux2ProGenerateRequest( prompt=prompt, @@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class Flux2MaxImageNode(Flux2ProImageNode): + + NODE_ID = "Flux2MaxImageNode" + DISPLAY_NAME = "Flux.2 [max] Image" + API_ENDPOINT = "/proxy/bfl/flux-2-max/generate" + + class BFLExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension): FluxProExpandNode, FluxProFillNode, Flux2ProImageNode, + Flux2MaxImageNode, ] diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 57c0218d0..636cc1265 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -112,7 +112,7 @@ class ByteDanceImageNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), @@ -215,7 +215,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + 
default=False, tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), @@ -346,7 +346,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image.', optional=True, ), @@ -380,7 +380,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): sequential_image_generation: str = "disabled", max_images: int = 1, seed: int = 0, - watermark: bool = True, + watermark: bool = False, fail_on_partial: bool = True, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) @@ -507,7 +507,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -617,7 +617,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -739,7 +739,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), @@ -862,7 +862,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index ad0f4b4d1..e8ed7e797 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -34,6 +34,7 @@ from comfy_api_nodes.util import ( ApiEndpoint, audio_to_base64_string, bytesio_to_image_tensor, + download_url_to_image_tensor, get_number_of_images, sync_op, tensor_to_base64_string, @@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera ) parts = [] for part in response.candidates[0].content.parts: - if part_type == "text" and hasattr(part, "text") and part.text: + if part_type == "text" and part.text: parts.append(part) - elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: + elif part.inlineData and part.inlineData.mimeType == part_type: + parts.append(part) + elif part.fileData and part.fileData.mimeType == part_type: parts.append(part) # Skip parts that don't match the requested type return parts @@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: +async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/png") for part in parts: - image_data = base64.b64decode(part.inlineData.data) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + if part.inlineData: + image_data = base64.b64decode(part.inlineData.data) + returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + else: + returned_image = await download_url_to_image_tensor(part.fileData.fileUri) image_tensors.append(returned_image) if len(image_tensors) == 0: return torch.zeros((1, 1024, 1024, 4)) @@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode): response = await sync_op( cls, - endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", 
method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiImage2(IO.ComfyNode): @@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode): response = await sync_op( cls, - ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 1a6364fa0..5294b10d4 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -858,7 +858,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): tooltip="A text prompt describing the video content. " "This can include both positive and negative descriptions.", ), - IO.Combo.Input("duration", options=["5", "10"]), + IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider), IO.Image.Input("first_frame"), IO.Image.Input( "end_frame", @@ -897,6 +897,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): validate_string(prompt, min_length=1, max_length=2500) if end_frame is not None and reference_images is not None: raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") + if duration not in (5, 10) and end_frame is None and reference_images is None: + raise ValueError( + "Duration is only supported for 5 or 10 seconds if there is no end frame or reference images." 
+ ) validate_image_dimensions(first_frame, min_width=300, min_height=300) validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) image_list: list[OmniParamImage] = [ diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index f522756e5..b04575ad8 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -23,10 +23,6 @@ UPSCALER_MODELS_MAP = { "Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Creative": "slc-1", } -UPSCALER_VALUES_MAP = { - "FullHD (1080p)": 1920, - "4K (2160p)": 3840, -} class TopazImageEnhance(IO.ComfyNode): @@ -214,7 +210,7 @@ class TopazVideoEnhance(IO.ComfyNode): IO.Video.Input("video"), IO.Boolean.Input("upscaler_enabled", default=True), IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), - IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())), + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), IO.Combo.Input( "upscaler_creativity", options=["low", "middle", "high"], @@ -306,8 +302,33 @@ class TopazVideoEnhance(IO.ComfyNode): target_frame_rate = src_frame_rate filters = [] if upscaler_enabled: - target_width = UPSCALER_VALUES_MAP[upscaler_resolution] - target_height = UPSCALER_VALUES_MAP[upscaler_resolution] + if "1080p" in upscaler_resolution: + target_pixel_p = 1080 + max_long_side = 1920 + else: + target_pixel_p = 2160 + max_long_side = 3840 + ar = src_width / src_height + if src_width >= src_height: + # Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width + target_height = target_pixel_p + target_width = int(target_height * ar) + # Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs) + if target_width > max_long_side: + target_width = max_long_side + target_height = int(target_width / ar) + else: + # Portrait; Attempt to set width to target (e.g., 2160), calculate height + target_width = target_pixel_p + target_height = int(target_width / ar) + # Check if height exceeds standard bounds + if target_height > max_long_side: + target_height = max_long_side + target_width = int(target_height * ar) + if target_width % 2 != 0: + target_width += 1 + if target_height % 2 != 0: + target_height += 1 filters.append( topaz_api.VideoEnhancementFilter( model=UPSCALER_MODELS_MAP[upscaler_model], diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 17b680e13..1675fd863 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -46,14 +46,14 @@ class Txt2ImageParametersField(BaseModel): n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) class Image2ImageParametersField(BaseModel): size: str | None = Field(None) n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) - watermark: bool = Field(True) + watermark: bool = Field(False) class Text2VideoParametersField(BaseModel): @@ -61,7 +61,7 @@ class Text2VideoParametersField(BaseModel): seed: int = Field(..., ge=0, le=2147483647) duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) audio: bool = Field(False, description="Whether to generate audio automatically.") shot_type: str = Field("single") @@ -71,7 +71,7 @@ class Image2VideoParametersField(BaseModel): seed: int 
= Field(..., ge=0, le=2147483647) duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) audio: bool = Field(False, description="Whether to generate audio automatically.") shot_type: str = Field("single") @@ -208,7 +208,7 @@ class WanTextToImageApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -234,7 +234,7 @@ class WanTextToImageApi(IO.ComfyNode): height: int = 1024, seed: int = 0, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, ): initial_response = await sync_op( cls, @@ -327,7 +327,7 @@ class WanImageToImageApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -353,7 +353,7 @@ class WanImageToImageApi(IO.ComfyNode): # width: int = 1024, # height: int = 1024, seed: int = 0, - watermark: bool = True, + watermark: bool = False, ): n_images = get_number_of_images(image) if n_images not in (1, 2): @@ -476,7 +476,7 @@ class WanTextToVideoApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -512,7 +512,7 @@ class WanTextToVideoApi(IO.ComfyNode): seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, shot_type: str = "single", ): if "480p" in size and model == "wan2.6-t2v": @@ -637,7 +637,7 @@ class WanImageToVideoApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, ), @@ -674,7 +674,7 @@ class WanImageToVideoApi(IO.ComfyNode): seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, shot_type: str = "single", ): if get_number_of_images(image) != 1: diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 7ee4caac1..993889d9d 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -760,8 +760,12 @@ class SamplerCustom(io.ComfyNode): out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) @@ -948,8 +952,12 @@ class SamplerCustomAdvanced(io.ComfyNode): out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) diff --git a/comfy_extras/nodes_latent.py 
diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py
index e439b18ef..2815c5ffc 100644
--- a/comfy_extras/nodes_latent.py
+++ b/comfy_extras/nodes_latent.py
@@ -5,6 +5,7 @@ import nodes
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
 import logging
+import math

 def reshape_latent_to(target_shape, latent, repeat_batch=True):
     if latent.shape[1:] != target_shape[1:]:
@@ -207,6 +208,47 @@ class LatentCut(io.ComfyNode):
         samples_out["samples"] = torch.narrow(s1, dim, index, amount)
         return io.NodeOutput(samples_out)

+class LatentCutToBatch(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentCutToBatch",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input("samples"),
+                io.Combo.Input("dim", options=["t", "x", "y"]),
+                io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
+        samples_out = samples.copy()
+
+        s1 = samples["samples"]
+
+        if "x" in dim:
+            dim = s1.ndim - 1
+        elif "y" in dim:
+            dim = s1.ndim - 2
+        elif "t" in dim:
+            dim = s1.ndim - 3
+
+        if dim < 2:
+            return io.NodeOutput(samples)
+
+        s = s1.movedim(dim, 1)
+        if s.shape[1] < slice_size:
+            slice_size = s.shape[1]
+        elif s.shape[1] % slice_size != 0:
+            s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
+        new_shape = [-1, slice_size] + list(s.shape[2:])
+        samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
+        return io.NodeOutput(samples_out)
+
+
 class LatentBatch(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -435,6 +477,7 @@ class LatentExtension(ComfyExtension):
             LatentInterpolate,
             LatentConcat,
             LatentCut,
+            LatentCutToBatch,
             LatentBatch,
             LatentBatchSeedBehavior,
             LatentApplyOperation,
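The new `LatentCutToBatch` node folds consecutive slices of the chosen axis into the batch dimension, truncating any remainder that does not divide evenly by `slice_size`. A rough sketch of the same reshape on a bare 5D tensor (the shapes are arbitrary example values):

```python
import torch

samples = torch.randn(1, 16, 9, 32, 32)  # [batch, channels, t, h, w]
slice_size = 4
dim = samples.ndim - 3                   # the "t" axis, as in the node above

s = samples.movedim(dim, 1)              # [1, 9, 16, 32, 32]
usable = (s.shape[1] // slice_size) * slice_size
s = s[:, :usable]                        # 9 frames -> 8; the trailing frame is dropped
out = s.reshape([-1, slice_size] + list(s.shape[2:])).movedim(1, dim)
print(out.shape)                         # torch.Size([2, 16, 4, 32, 32])
```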
io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + + class QwenExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeQwenImageEdit, TextEncodeQwenImageEditPlus, + EmptyQwenImageLayeredLatentImage, ] diff --git a/comfyui_version.py b/comfyui_version.py index b45309198..1f28e2407 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.5.1" +__version__ = "0.6.0" diff --git a/manager_requirements.txt b/manager_requirements.txt index 5ef0d3a1d..2300f0c70 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b5 +comfyui_manager==4.0.3b7 diff --git a/nodes.py b/nodes.py index b13ceb578..7d83ecb21 100644 --- a/nodes.py +++ b/nodes.py @@ -970,7 +970,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -980,7 +980,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) diff --git a/pyproject.toml b/pyproject.toml index 3a6960811..35a268bd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.5.1" +version = "0.6.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 54696395f..59ac599c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.34.9 -comfyui-workflow-templates==0.7.60 +comfyui-workflow-templates==0.7.63 comfyui-embedded-docs==0.3.1 torch torchsde