diff --git a/README.md b/README.md index d97c6181d..56b7966cf 100644 --- a/README.md +++ b/README.md @@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) -[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z). - [Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). #### How do I share models between another UI and ComfyUI? diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py index 03b603c70..d9aab5b22 100644 --- a/app/node_replace_manager.py +++ b/app/node_replace_manager.py @@ -46,6 +46,8 @@ class NodeReplaceManager: connections: dict[str, list[tuple[str, str, int]]] = {} need_replacement: set[str] = set() for node_number, node_struct in prompt.items(): + if "class_type" not in node_struct or "inputs" not in node_struct: + continue class_type = node_struct["class_type"] # need replacement if not in NODE_CLASS_MAPPINGS and has replacement if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type): diff --git a/comfy/conds.py b/comfy/conds.py index 5af3e93ea..55d8cdd78 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -4,6 +4,25 @@ import comfy.utils import logging +def is_equal(x, y): + if torch.is_tensor(x) and torch.is_tensor(y): + return torch.equal(x, y) + elif isinstance(x, dict) and isinstance(y, dict): + if x.keys() != y.keys(): + return False + return all(is_equal(x[k], y[k]) for k in x) + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + if type(x) is not type(y) or len(x) != len(y): + return False + return all(is_equal(a, b) for a, b in zip(x, y)) + else: + try: + return x == y + except Exception: + logging.warning("comparison issue with COND") + return False + + class CONDRegular: def __init__(self, cond): self.cond = cond @@ -84,7 +103,7 @@ class CONDConstant(CONDRegular): return self._copy_with(self.cond) def can_concat(self, other): - if self.cond != other.cond: + if not is_equal(self.cond, other.cond): return False return True diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 2b080aaeb..553fd5b38 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -218,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module): def forward( self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, - v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, + v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -234,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module): vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2))) norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa del vshift_msa, vscale_msa - attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) + attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options) del norm_vx # video cross-attention vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] @@ -726,7 +726,7 @@ class LTXAVModel(LTXVModel): return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] def _process_transformer_blocks( - self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs + self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs ): vx = x[0] ax = x[1] @@ -770,6 +770,7 @@ class LTXAVModel(LTXVModel): v_cross_gate_timestep=args["v_cross_gate_timestep"], a_cross_gate_timestep=args["a_cross_gate_timestep"], transformer_options=args["transformer_options"], + self_attention_mask=args.get("self_attention_mask"), ) return out @@ -790,6 +791,7 @@ class LTXAVModel(LTXVModel): "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "transformer_options": transformer_options, + "self_attention_mask": self_attention_mask, }, {"original_block": block_wrap}, ) @@ -811,6 +813,7 @@ class LTXAVModel(LTXVModel): v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, transformer_options=transformer_options, + self_attention_mask=self_attention_mask, ) return [vx, ax] diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index d61e19d6e..60d760d29 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from enum import Enum import functools +import logging import math from typing import Dict, Optional, Tuple @@ -14,6 +15,8 @@ import comfy.ldm.common_dit from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords +logger = logging.getLogger(__name__) + def _log_base(x, base): return np.log(x) / np.log(base) @@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module): self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) attn1_input = comfy.ldm.common_dit.rms_norm(x) attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) x.addcmul_(attn1_input, gate_msa) del attn1_input @@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC): """Process input data. Must be implemented by subclasses.""" pass + def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): + """Build self-attention mask for per-guide attention attenuation. + + Base implementation returns None (no attenuation). Subclasses that + support guide-based attention control should override this. + """ + return None + @abstractmethod - def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs): """Process transformer blocks. Must be implemented by subclasses.""" pass @@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC): attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) + # Build self-attention mask for per-guide attenuation + self_attention_mask = self._build_guide_self_attention_mask( + x, transformer_options, merged_args + ) + # Process transformer blocks x = self._process_transformer_blocks( - x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args + x, context, attention_mask, timestep, pe, + transformer_options=transformer_options, + self_attention_mask=self_attention_mask, + **merged_args, ) # Process output @@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel): pixel_coords = pixel_coords[:, :, grid_mask, ...] kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] + + # Compute per-guide surviving token counts from guide_attention_entries. + # Each entry tracks one guide reference; they are appended in order and + # their pre_filter_counts partition the kf_grid_mask. + guide_entries = kwargs.get("guide_attention_entries", None) + if guide_entries: + total_pfc = sum(e["pre_filter_count"] for e in guide_entries) + if total_pfc != len(kf_grid_mask): + raise ValueError( + f"guide pre_filter_counts ({total_pfc}) != " + f"keyframe grid mask length ({len(kf_grid_mask)})" + ) + resolved_entries = [] + offset = 0 + for entry in guide_entries: + pfc = entry["pre_filter_count"] + entry_mask = kf_grid_mask[offset:offset + pfc] + surviving = int(entry_mask.sum().item()) + resolved_entries.append({ + **entry, + "surviving_count": surviving, + }) + offset += pfc + additional_args["resolved_guide_entries"] = resolved_entries + keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs + # Total surviving guide tokens (all guides) + additional_args["num_guide_tokens"] = keyframe_idxs.shape[2] + x = self.patchify_proj(x) return x, pixel_coords, additional_args - def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): + def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): + """Build self-attention mask for per-guide attention attenuation. + + Reads resolved_guide_entries from merged_args (computed in _process_input) + to build a log-space additive bias mask that attenuates noisy ↔ guide + attention for each guide reference independently. + + Returns None if no attenuation is needed (all strengths == 1.0 and no + spatial masks, or no guide tokens). + """ + if isinstance(x, list): + # AV model: x = [vx, ax]; use vx for token count and device + total_tokens = x[0].shape[1] + device = x[0].device + dtype = x[0].dtype + else: + total_tokens = x.shape[1] + device = x.device + dtype = x.dtype + + num_guide_tokens = merged_args.get("num_guide_tokens", 0) + if num_guide_tokens == 0: + return None + + resolved_entries = merged_args.get("resolved_guide_entries", None) + if not resolved_entries: + return None + + # Check if any attenuation is actually needed + needs_attenuation = any( + e["strength"] < 1.0 or e.get("pixel_mask") is not None + for e in resolved_entries + ) + if not needs_attenuation: + return None + + # Build per-guide-token weights for all tracked guide tokens. + # Guides are appended in order at the end of the sequence. + guide_start = total_tokens - num_guide_tokens + all_weights = [] + total_tracked = 0 + + for entry in resolved_entries: + surviving = entry["surviving_count"] + if surviving == 0: + continue + + strength = entry["strength"] + pixel_mask = entry.get("pixel_mask") + latent_shape = entry.get("latent_shape") + + if pixel_mask is not None and latent_shape is not None: + f_lat, h_lat, w_lat = latent_shape + per_token = self._downsample_mask_to_latent( + pixel_mask.to(device=device, dtype=dtype), + f_lat, h_lat, w_lat, + ) + # per_token shape: (B, f_lat*h_lat*w_lat). + # Collapse batch dim — the mask is assumed identical across the + # batch; validate and take the first element to get (1, tokens). + if per_token.shape[0] > 1: + ref = per_token[0] + for bi in range(1, per_token.shape[0]): + if not torch.equal(ref, per_token[bi]): + logger.warning( + "pixel_mask differs across batch elements; " + "using first element only." + ) + break + per_token = per_token[:1] + # `surviving` is the post-grid_mask token count. + # Clamp to surviving to handle any mismatch safely. + n_weights = min(per_token.shape[1], surviving) + weights = per_token[:, :n_weights] * strength # (1, n_weights) + else: + weights = torch.full( + (1, surviving), strength, device=device, dtype=dtype + ) + + all_weights.append(weights) + total_tracked += weights.shape[1] + + if not all_weights: + return None + + # Concatenate per-token weights for all tracked guides + tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked) + + # Check if any weight is actually < 1.0 (otherwise no attenuation needed) + if (tracked_weights >= 1.0).all(): + return None + + # Build the mask: guide tokens are at the end of the sequence. + # Tracked guides come first (in order), untracked follow. + return self._build_self_attention_mask( + total_tokens, num_guide_tokens, total_tracked, + tracked_weights, guide_start, device, dtype, + ) + + @staticmethod + def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat): + """Downsample a pixel-space mask to per-token latent weights. + + Args: + mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1]. + f_lat: Number of latent frames (pre-dilation original count). + h_lat: Latent height (pre-dilation original height). + w_lat: Latent width (pre-dilation original width). + + Returns: + (B, F_lat * H_lat * W_lat) flattened per-token weights. + """ + b = mask.shape[0] + f_pix = mask.shape[2] + + # Spatial downsampling: area interpolation per frame + spatial_down = torch.nn.functional.interpolate( + rearrange(mask, "b 1 f h w -> (b f) 1 h w"), + size=(h_lat, w_lat), + mode="area", + ) + spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b) + + # Temporal downsampling: first pixel frame maps to first latent frame, + # remaining pixel frames are averaged in groups for causal temporal structure. + first_frame = spatial_down[:, :, :1, :, :] + if f_pix > 1 and f_lat > 1: + remaining_pix = f_pix - 1 + remaining_lat = f_lat - 1 + t = remaining_pix // remaining_lat + if t < 1: + # Fewer pixel frames than latent frames — upsample by repeating + # the available pixel frames via nearest interpolation. + rest_flat = rearrange( + spatial_down[:, :, 1:, :, :], + "b 1 f h w -> (b h w) 1 f", + ) + rest_up = torch.nn.functional.interpolate( + rest_flat, size=remaining_lat, mode="nearest", + ) + rest = rearrange( + rest_up, "(b h w) 1 f -> b 1 f h w", + b=b, h=h_lat, w=w_lat, + ) + else: + # Trim trailing pixel frames that don't fill a complete group + usable = remaining_lat * t + rest = rearrange( + spatial_down[:, :, 1:1 + usable, :, :], + "b 1 (f t) h w -> b 1 f t h w", + t=t, + ) + rest = rest.mean(dim=3) + latent_mask = torch.cat([first_frame, rest], dim=2) + elif f_lat > 1: + # Single pixel frame but multiple latent frames — repeat the + # single frame across all latent frames. + latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1) + else: + latent_mask = first_frame + + return rearrange(latent_mask, "b 1 f h w -> b (f h w)") + + @staticmethod + def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count, + tracked_weights, guide_start, device, dtype): + """Build a log-space additive self-attention bias mask. + + Attenuates attention between noisy tokens and tracked guide tokens. + Untracked guide tokens (at the end of the guide portion) keep full attention. + + Args: + total_tokens: Total sequence length. + num_guide_tokens: Total guide tokens (all guides) at end of sequence. + tracked_count: Number of tracked guide tokens (first in the guide portion). + tracked_weights: (1, tracked_count) tensor, values in [0, 1]. + guide_start: Index where guide tokens begin in the sequence. + device: Target device. + dtype: Target dtype. + + Returns: + (1, 1, total_tokens, total_tokens) additive bias mask. + 0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked. + """ + finfo = torch.finfo(dtype) + mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype) + tracked_end = guide_start + tracked_count + + # Convert weights to log-space bias + w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count) + log_w = torch.full_like(w, finfo.min) + positive_mask = w > 0 + if positive_mask.any(): + log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny)) + + # noisy → tracked guides: each noisy row gets the same per-guide weight + mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1) + # tracked guides → noisy: each guide row broadcasts its weight across noisy cols + mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1) + + return mask + + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel): def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel): timestep=timestep, pe=pe, transformer_options=transformer_options, + self_attention_mask=self_attention_mask, ) return x diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index fd125ceed..7903c7690 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -459,6 +459,7 @@ class WanVAE(nn.Module): attn_scales=[], temperal_downsample=[True, True, False], image_channels=3, + conv_out_channels=3, dropout=0.0): super().__init__() self.dim = dim @@ -474,7 +475,7 @@ class WanVAE(nn.Module): attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks, + self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def encode(self, x): diff --git a/comfy/model_base.py b/comfy/model_base.py index 2f49578f6..8f852e3c6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -76,6 +76,7 @@ class ModelType(Enum): FLUX = 8 IMG_TO_IMG = 9 FLOW_COSMOS = 10 + IMG_TO_IMG_FLOW = 11 def model_sampling(model_config, model_type): @@ -108,6 +109,8 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLOW_COSMOS: c = comfy.model_sampling.COSMOS_RFLOW s = comfy.model_sampling.ModelSamplingCosmosRFlow + elif model_type == ModelType.IMG_TO_IMG_FLOW: + c = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSampling(s, c): pass @@ -971,6 +974,10 @@ class LTXV(BaseModel): if keyframe_idxs is not None: out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + return out def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): @@ -1023,6 +1030,10 @@ class LTXAV(BaseModel): if latent_shapes is not None: out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + return out def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): @@ -1466,6 +1477,12 @@ class WAN22(WAN21): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class WAN21_FlowRVS(WAN21): + def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None): + model_config.unet_config["model_type"] = "t2v" + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) + self.image_to_video = image_to_video + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 30ea03e8e..030ae6980 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -509,6 +509,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if ref_conv_weight is not None: dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + if metadata is not None and "config" in metadata: + dit_config.update(json.loads(metadata["config"]).get("transformer", {})) + return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 21b4ce53e..1c9ba8096 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -271,6 +271,7 @@ class ModelPatcher: self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed + self.cached_patcher_init: tuple[Callable, tuple] | None = None if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -307,8 +308,15 @@ class ModelPatcher: def get_free_memory(self, device): return comfy.model_management.get_free_memory(device) - def clone(self): - n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) + def clone(self, disable_dynamic=False): + class_ = self.__class__ + model = self.model + if self.is_dynamic() and disable_dynamic: + class_ = ModelPatcher + temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) + model = temp_model_patcher.model + + n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -362,6 +370,8 @@ class ModelPatcher: n.is_clip = self.is_clip n.hook_mode = self.hook_mode + n.cached_patcher_init = self.cached_patcher_init + for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 2a00ed819..13860e6a2 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -83,6 +83,16 @@ class IMG_TO_IMG(X0): def calculate_input(self, sigma, noise): return noise +class IMG_TO_IMG_FLOW(CONST): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + return latent_image + + def inverse_noise_scaling(self, sigma, latent): + return 1.0 - latent + class COSMOS_RFLOW: def calculate_input(self, sigma, noise): sigma = (sigma / (sigma + 1)) diff --git a/comfy/ops.py b/comfy/ops.py index a6c642795..98fec1e1d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,7 +19,7 @@ import torch import logging import comfy.model_management -from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram +from comfy.cli_args import args, PerformanceFeature import comfy.float import json import comfy.memory_management @@ -296,7 +296,7 @@ class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): - if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: super().__init__(in_features, out_features, bias, device, dtype) return @@ -317,7 +317,7 @@ class disable_weight_init: def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) @@ -827,6 +827,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec else: sd = {} + if not hasattr(self, 'weight'): + logging.warning("Warning: state dict on uninitialized op {}".format(prefix)) + return sd + if self.bias is not None: sd["{}bias".format(prefix)] = self.bias diff --git a/comfy/sd.py b/comfy/sd.py index ce6ca5d17..de119eb8e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -694,8 +694,9 @@ class VAE: self.latent_dim = 3 self.latent_channels = 16 self.output_channels = sd["encoder.conv1.weight"].shape[1] + self.conv_out_channels = sd["decoder.head.2.weight"].shape[0] self.pad_channel_value = 1.0 - ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) @@ -1530,14 +1531,24 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) -def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) - out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata) + out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) + if output_model: + out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) return out -def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): +def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, + embedding_directory=embedding_directory, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic) + return model + +def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False): clip = None clipvision = None vae = None @@ -1586,7 +1597,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: @@ -1637,7 +1649,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): +def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False): """ Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. @@ -1721,7 +1733,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) if not model_management.is_device_cpu(offload_device): model.to(offload_device) model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) @@ -1730,12 +1743,13 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): logging.info("left over keys in diffusion model: {}".format(left_over)) return model_patcher -def load_diffusion_model(unet_path, model_options={}): +def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) - model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata) + model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic) if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) + model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model def load_unet(unet_path, dtype=None): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c28be1716..6d08ff0a5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1256,6 +1256,16 @@ class WAN22_T2V(WAN21_T2V): out = model_base.WAN22(self, image_to_video=True, device=device) return out +class WAN21_FlowRVS(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "flow_rvs", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1667,6 +1677,6 @@ class ACEStep15(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 64ce64f89..e86ea9f4e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -6,6 +6,7 @@ import comfy.text_encoders.genmo import torch import comfy.utils import math +import itertools class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -72,7 +73,7 @@ class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) special_tokens = {"": 262144, "": 106} - super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): @@ -199,8 +200,10 @@ class LTXAVTEModel(torch.nn.Module): constant /= 2.0 token_weight_pairs = token_weight_pairs.get("gemma3_12b", []) - num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - num_tokens = max(num_tokens, 64) + m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs]) + + num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m + num_tokens = max(num_tokens, 642) return num_tokens * constant * 1024 * 1024 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): diff --git a/comfy/utils.py b/comfy/utils.py index 1558d0a4e..b26bb6faf 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,7 +29,7 @@ import itertools from torch.nn.functional import interpolate from tqdm.auto import trange from einops import rearrange -from comfy.cli_args import args, enables_dynamic_vram +from comfy.cli_args import args import json import time import mmap @@ -113,7 +113,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: - if enables_dynamic_vram(): + if comfy.memory_management.aimdo_enabled: sd, metadata = load_safetensors(ckpt) if not return_metadata: metadata = None diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 23cbe2372..18455396d 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -27,6 +27,7 @@ class Seedream4TaskCreationRequest(BaseModel): sequential_image_generation: str = Field("disabled") sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) watermark: bool = Field(False) + output_format: str | None = None class ImageTaskCreationResponse(BaseModel): @@ -106,6 +107,7 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [ ("2496x1664 (3:2)", 2496, 1664), ("1664x2496 (2:3)", 1664, 2496), ("3024x1296 (21:9)", 3024, 1296), + ("3072x3072 (1:1)", 3072, 3072), ("4096x4096 (1:1)", 4096, 4096), ("Custom", None, None), ] diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 3304d7e76..639035fef 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -127,9 +127,15 @@ class GeminiImageConfig(BaseModel): imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions) +class GeminiThinkingConfig(BaseModel): + includeThoughts: bool | None = Field(None) + thinkingLevel: str = Field(...) + + class GeminiImageGenerationConfig(GeminiGenerationConfig): responseModalities: list[str] | None = Field(None) imageConfig: GeminiImageConfig | None = Field(None) + thinkingConfig: GeminiThinkingConfig | None = Field(None) class GeminiImageGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index ae5d59a79..6dbd5984e 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -37,6 +37,12 @@ from comfy_api_nodes.util import ( BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" +SEEDREAM_MODELS = { + "seedream 5.0 lite": "seedream-5-0-260128", + "seedream-4-5-251128": "seedream-4-5-251128", + "seedream-4-0-250828": "seedream-4-0-250828", +} + # Long-running tasks endpoints(e.g., video) BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} @@ -180,14 +186,13 @@ class ByteDanceSeedreamNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ByteDanceSeedreamNode", - display_name="ByteDance Seedream 4.5", + display_name="ByteDance Seedream 4.5 & 5.0", category="api node/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ IO.Combo.Input( "model", - options=["seedream-4-5-251128", "seedream-4-0-250828"], - tooltip="Model name", + options=list(SEEDREAM_MODELS.keys()), ), IO.String.Input( "prompt", @@ -198,7 +203,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): IO.Image.Input( "image", tooltip="Input image(s) for image-to-image generation. " - "List of 1-10 images for single or multi-reference generation.", + "Reference image(s) for single or multi-reference generation.", optional=True, ), IO.Combo.Input( @@ -210,8 +215,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode): "width", default=2048, min=1024, - max=4096, - step=8, + max=6240, + step=2, tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -219,8 +224,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode): "height", default=2048, min=1024, - max=4096, - step=8, + max=4992, + step=2, tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -283,7 +288,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode): depends_on=IO.PriceBadgeDepends(widgets=["model"]), expr=""" ( - $price := $contains(widgets.model, "seedream-4-5-251128") ? 0.04 : 0.03; + $price := $contains(widgets.model, "5.0 lite") ? 0.035 : + $contains(widgets.model, "4-5") ? 0.04 : 0.03; { "type":"usd", "usd": $price, @@ -309,6 +315,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): watermark: bool = False, fail_on_partial: bool = True, ) -> IO.NodeOutput: + model = SEEDREAM_MODELS[model] validate_string(prompt, strip_whitespace=True, min_length=1) w = h = None for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4: @@ -318,15 +325,12 @@ class ByteDanceSeedreamNode(IO.ComfyNode): if w is None or h is None: w, h = width, height - if not (1024 <= w <= 4096) or not (1024 <= h <= 4096): - raise ValueError( - f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." - ) + out_num_pixels = w * h mp_provided = out_num_pixels / 1_000_000.0 - if "seedream-4-5" in model and out_num_pixels < 3686400: + if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400: raise ValueError( - f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, " + f"Minimum image resolution for the selected model is 3.68MP, " f"but {mp_provided:.2f}MP provided." ) if "seedream-4-0" in model and out_num_pixels < 921600: @@ -334,9 +338,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode): f"Minimum image resolution that the selected model can generate is 0.92MP, " f"but {mp_provided:.2f}MP provided." ) + max_pixels = 10_404_496 if "seedream-5-0" in model else 16_777_216 + if out_num_pixels > max_pixels: + raise ValueError( + f"Maximum image resolution for the selected model is {max_pixels / 1_000_000:.2f}MP, " + f"but {mp_provided:.2f}MP provided." + ) n_input_images = get_number_of_images(image) if image is not None else 0 - if n_input_images > 10: - raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") + max_num_of_images = 14 if model == "seedream-5-0-260128" else 10 + if n_input_images > max_num_of_images: + raise ValueError( + f"Maximum of {max_num_of_images} reference images are supported, but {n_input_images} received." + ) if sequential_image_generation == "auto" and n_input_images + max_images > 15: raise ValueError( "The maximum number of generated images plus the number of reference images cannot exceed 15." @@ -364,6 +377,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): sequential_image_generation=sequential_image_generation, sequential_image_generation_options=Seedream4Options(max_images=max_images), watermark=watermark, + output_format="png" if model == "seedream-5-0-260128" else None, ), ) if len(response.data) == 1: diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index b69285be5..3fe804e0b 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -29,6 +29,7 @@ from comfy_api_nodes.apis.gemini import ( GeminiRole, GeminiSystemInstructionContent, GeminiTextPart, + GeminiThinkingConfig, Modality, ) from comfy_api_nodes.util import ( @@ -55,6 +56,21 @@ GEMINI_IMAGE_SYS_PROMPT = ( "Prioritize generating the visual representation above any text, formatting, or conversational requests." ) +GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution"]), + expr=""" + ( + $m := widgets.model; + $r := widgets.resolution; + $isFlash := $contains($m, "nano banana 2"); + $flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123}; + $proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24}; + $prices := $isFlash ? $flashPrices : $proPrices; + {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} + ) + """, +) + class GeminiModel(str, Enum): """ @@ -229,6 +245,10 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 120.0 + elif response.modelVersion == "gemini-3.1-flash-image-preview": + input_tokens_price = 0.5 + output_text_tokens_price = 3.0 + output_image_tokens_price = 60.0 else: return None final_price = response.usageMetadata.promptTokenCount * input_tokens_price @@ -686,7 +706,7 @@ class GeminiImage2(IO.ComfyNode): ), IO.Combo.Input( "model", - options=["gemini-3-pro-image-preview"], + options=["gemini-3-pro-image-preview", "Nano Banana 2 (Gemini 3.1 Flash Image)"], ), IO.Int.Input( "seed", @@ -750,19 +770,7 @@ class GeminiImage2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), - expr=""" - ( - $r := widgets.resolution; - ($contains($r,"1k") or $contains($r,"2k")) - ? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}} - : $contains($r,"4k") - ? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}} - : {"type":"text","text":"Token-based"} - ) - """, - ), + price_badge=GEMINI_IMAGE_2_PRICE_BADGE, ) @classmethod @@ -779,6 +787,10 @@ class GeminiImage2(IO.ComfyNode): system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model = "gemini-3.1-flash-image-preview" + if response_modalities == "IMAGE+TEXT": + raise ValueError("IMAGE+TEXT is not currently available for the Nano Banana 2 model.") parts: list[GeminiPart] = [GeminiPart(text=prompt)] if images is not None: @@ -815,6 +827,168 @@ class GeminiImage2(IO.ComfyNode): return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) +class GeminiNanoBanana2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiNanoBanana2", + display_name="Nano Banana 2", + category="api node/image/Gemini", + description="Generate or edit images synchronously via Google Vertex API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt describing the image to generate or the edits to apply. " + "Include any constraints, styles, or details the model should follow.", + default="", + ), + IO.Combo.Input( + "model", + options=["Nano Banana 2 (Gemini 3.1 Flash Image)"], + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "4:5", + "5:4", + "9:16", + "16:9", + "21:9", + # "1:4", + # "4:1", + # "8:1", + # "1:8", + ], + default="auto", + tooltip="If set to 'auto', matches your input image's aspect ratio; " + "if no image is provided, a 16:9 square is usually generated.", + ), + IO.Combo.Input( + "resolution", + options=[ + # "512px", + "1K", + "2K", + "4K", + ], + tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.", + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE"], + advanced=True, + ), + IO.Combo.Input( + "thinking_level", + options=["MINIMAL", "HIGH"], + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional reference image(s). " + "To include multiple images, use the Batch Images node (up to 14).", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + advanced=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=GEMINI_IMAGE_2_PRICE_BADGE, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + aspect_ratio: str, + resolution: str, + response_modalities: str, + thinking_level: str, + images: Input.Image | None = None, + files: list[GeminiPart] | None = None, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model = "gemini-3.1-flash-image-preview" + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + if images is not None: + if get_number_of_images(images) > 14: + raise ValueError("The current maximum number of supported images is 14.") + parts.extend(await create_image_parts(cls, images)) + if files is not None: + parts.extend(files) + + image_config = GeminiImageConfig(imageSize=resolution) + if aspect_ratio != "auto": + image_config.aspectRatio = aspect_ratio + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=image_config, + thinkingConfig=GeminiThinkingConfig(thinkingLevel=thinking_level), + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) + + class GeminiExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -822,6 +996,7 @@ class GeminiExtension(ComfyExtension): GeminiNode, GeminiImage, GeminiImage2, + GeminiNanoBanana2, GeminiInputFiles, ] diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py index 75ffb6d80..6d210b307 100644 --- a/comfy_extras/nodes_glsl.py +++ b/comfy_extras/nodes_glsl.py @@ -717,11 +717,11 @@ def _render_shader_batch( gl.glUseProgram(0) for tex in input_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) for tex in output_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) for tex in ping_pong_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) if fbo is not None: gl.glDeleteFramebuffers(1, [fbo]) for pp_fbo in ping_pong_fbos: diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 1eeeec011..32fe921ff 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode): generate = execute # TODO: remove +def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0): + """Append a guide_attention_entry to both positive and negative conditioning. + + Each entry tracks one guide reference for per-reference attention control. + Entries are derived independently from each conditioning to avoid cross-contamination. + """ + new_entry = { + "pre_filter_count": pre_filter_count, + "strength": strength, + "pixel_mask": None, + "latent_shape": latent_shape, + } + results = [] + for cond in (positive, negative): + # Read existing entries from this specific conditioning + existing = [] + for t in cond: + found = t[1].get("guide_attention_entries", None) + if found is not None: + existing = found + break + # Shallow copy and append (no deepcopy needed — entries contain + # only scalars and None for pixel_mask at this call site). + entries = [*existing, new_entry] + results.append(node_helpers.conditioning_set_values( + cond, {"guide_attention_entries": entries} + )) + return results[0], results[1] + + def conditioning_get_any_value(conditioning, key, default=None): for t in conditioning: if key in t[1]: @@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode): scale_factors, ) + # Track this guide for per-reference attention control. + pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4] + guide_latent_shape = list(t.shape[2:]) # [F, H, W] + positive, negative = _append_guide_attention_entry( + positive, negative, pre_filter_count, guide_latent_shape, strength=strength, + ) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) generate = execute # TODO: remove @@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode): latent_image = latent_image[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes] - positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) - negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) + positive = node_helpers.conditioning_set_values(positive, { + "keyframe_idxs": None, + "guide_attention_entries": None, + }) + negative = node_helpers.conditioning_set_values(negative, { + "keyframe_idxs": None, + "guide_attention_entries": None, + }) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 9601a5f76..8bf6a1afa 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -52,7 +52,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],), + "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],), "zsnr": ("BOOLEAN", {"default": False, "advanced": True}), }} @@ -76,6 +76,8 @@ class ModelSamplingDiscrete: sampling_type = comfy.model_sampling.X0 elif sampling == "img_to_img": sampling_type = comfy.model_sampling.IMG_TO_IMG + elif sampling == "img_to_img_flow": + sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSamplingAdvanced(sampling_base, sampling_type): pass diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index c67245e7a..4a0f7141a 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -79,7 +79,6 @@ class Blur(io.ComfyNode): node_id="ImageBlur", display_name="Image Blur", category="image/postprocessing", - essentials_category="Image Tools", inputs=[ io.Image.Input("image"), io.Int.Input("blur_radius", default=1, min=1, max=31, step=1), @@ -568,6 +567,7 @@ class BatchImagesNode(io.ComfyNode): node_id="BatchImagesNode", display_name="Batch Images", category="image", + essentials_category="Image Tools", search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ io.Autogrow.Input("images", template=autogrow_template) diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index 00e9f8b1f..c9e2e0026 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -25,7 +25,7 @@ class TorchCompileModel(io.ComfyNode): @classmethod def execute(cls, model, backend) -> io.NodeOutput: - m = model.clone() + m = model.clone(disable_dynamic=True) set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict}) return io.NodeOutput(m) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index db7c171a2..5c096c232 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -147,7 +147,6 @@ class GetVideoComponents(io.ComfyNode): search_aliases=["extract frames", "split video", "video to images", "demux"], display_name="Get Video Components", category="image/video", - essentials_category="Video Tools", description="Extracts all components from a video: frames, audio, and framerate.", inputs=[ io.Video.Input("video", tooltip="The video to extract components from."), @@ -218,6 +217,7 @@ class VideoSlice(io.ComfyNode): "start time", ], category="image/video", + essentials_category="Video Tools", inputs=[ io.Video.Input("video"), io.Float.Input( diff --git a/comfyui_version.py b/comfyui_version.py index f24c15cc5..6a35c6de3 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.14.1" +__version__ = "0.15.1" diff --git a/main.py b/main.py index 39e605deb..3fe8f0589 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,10 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. @@ -173,10 +177,6 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") -import comfy_aimdo.control - -if enables_dynamic_vram(): - comfy_aimdo.control.init() import comfy.utils diff --git a/nodes.py b/nodes.py index e2fc20d53..bff073e30 100644 --- a/nodes.py +++ b/nodes.py @@ -1925,7 +1925,6 @@ class ImageInvert: class ImageBatch: SEARCH_ALIASES = ["combine images", "merge images", "stack images"] - ESSENTIALS_CATEGORY = "Image Tools" @classmethod def INPUT_TYPES(s): diff --git a/pyproject.toml b/pyproject.toml index 51c3d224d..1b2318273 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.14.1" +version = "0.15.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index c0c662cd5..b5b292980 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -comfyui-frontend-package==1.39.16 -comfyui-workflow-templates==0.9.2 -comfyui-embedded-docs==0.4.1 +comfyui-frontend-package==1.39.19 +comfyui-workflow-templates==0.9.4 +comfyui-embedded-docs==0.4.3 torch torchsde torchvision @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.0 +comfy-aimdo>=0.2.2 requests #non essential dependencies: