Merge branch 'comfyanonymous:master' into offloader-maifee

This commit is contained in:
Maifee Ul Asad 2025-11-27 08:47:41 +06:00 committed by GitHub
commit cee75f301a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 3103 additions and 1508 deletions

View File

@ -67,6 +67,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)

View File

@ -6,6 +6,7 @@ class LatentFormat:
latent_dimensions = 2
latent_rgb_factors = None
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
taesd_decoder_name = None
def process_in(self, latent):
@ -181,6 +182,45 @@ class Flux(SD3):
class Flux2(LatentFormat):
latent_channels = 128
def __init__(self):
self.latent_rgb_factors =[
[0.0058, 0.0113, 0.0073],
[0.0495, 0.0443, 0.0836],
[-0.0099, 0.0096, 0.0644],
[0.2144, 0.3009, 0.3652],
[0.0166, -0.0039, -0.0054],
[0.0157, 0.0103, -0.0160],
[-0.0398, 0.0902, -0.0235],
[-0.0052, 0.0095, 0.0109],
[-0.3527, -0.2712, -0.1666],
[-0.0301, -0.0356, -0.0180],
[-0.0107, 0.0078, 0.0013],
[0.0746, 0.0090, -0.0941],
[0.0156, 0.0169, 0.0070],
[-0.0034, -0.0040, -0.0114],
[0.0032, 0.0181, 0.0080],
[-0.0939, -0.0008, 0.0186],
[0.0018, 0.0043, 0.0104],
[0.0284, 0.0056, -0.0127],
[-0.0024, -0.0022, -0.0030],
[0.1207, -0.0026, 0.0065],
[0.0128, 0.0101, 0.0142],
[0.0137, -0.0072, -0.0007],
[0.0095, 0.0092, -0.0059],
[0.0000, -0.0077, -0.0049],
[-0.0465, -0.0204, -0.0312],
[0.0095, 0.0012, -0.0066],
[0.0290, -0.0034, 0.0025],
[0.0220, 0.0169, -0.0048],
[-0.0332, -0.0457, -0.0468],
[-0.0085, 0.0389, 0.0609],
[-0.0076, 0.0003, -0.0043],
[-0.0111, -0.0460, -0.0614],
]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
def process_in(self, latent):
return latent

View File

@ -11,6 +11,7 @@ import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
@ -31,6 +32,7 @@ class JointAttention(nn.Module):
n_heads: int,
n_kv_heads: Optional[int],
qk_norm: bool,
out_bias: bool = False,
operation_settings={},
):
"""
@ -59,7 +61,7 @@ class JointAttention(nn.Module):
self.out = operation_settings.get("operations").Linear(
n_heads * self.head_dim,
dim,
bias=False,
bias=out_bias,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
@ -70,35 +72,6 @@ class JointAttention(nn.Module):
else:
self.q_norm = self.k_norm = nn.Identity()
@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x_in.shape)
def forward(
self,
x: torch.Tensor,
@ -134,8 +107,7 @@ class JointAttention(nn.Module):
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
xq, xk = apply_rope(xq, xk, freqs_cis)
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module):
norm_eps: float,
qk_norm: bool,
modulation=True,
z_image_modulation=False,
attn_out_bias=False,
operation_settings={},
) -> None:
"""
@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=4 * dim,
hidden_dim=dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
operation_settings=operation_settings,
@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module):
self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
if z_image_modulation:
self.adaLN_modulation = nn.Sequential(
operation_settings.get("operations").Linear(
min(dim, 256),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
else:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
def forward(
self,
@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
The final layer of NextDiT.
"""
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
super().__init__()
self.norm_final = operation_settings.get("operations").LayerNorm(
hidden_size,
@ -340,10 +325,15 @@ class FinalLayer(nn.Module):
dtype=operation_settings.get("dtype"),
)
if z_image_modulation:
min_mod = 256
else:
min_mod = 1024
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(hidden_size, 1024),
min(hidden_size, min_mod),
hidden_size,
bias=True,
device=operation_settings.get("device"),
@ -373,12 +363,16 @@ class NextDiT(nn.Module):
n_heads: int = 32,
n_kv_heads: Optional[int] = None,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
ffn_dim_multiplier: float = 4.0,
norm_eps: float = 1e-5,
qk_norm: bool = False,
cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
rope_theta=10000.0,
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
image_model=None,
device=None,
dtype=None,
@ -390,6 +384,8 @@ class NextDiT(nn.Module):
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.time_scale = time_scale
self.pad_tokens_multiple = pad_tokens_multiple
self.x_embedder = operation_settings.get("operations").Linear(
in_features=patch_size * patch_size * in_channels,
@ -411,6 +407,7 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=z_image_modulation,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
@ -434,7 +431,7 @@ class NextDiT(nn.Module):
]
)
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
self.cap_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
@ -457,18 +454,24 @@ class NextDiT(nn.Module):
ffn_dim_multiplier,
norm_eps,
qk_norm,
z_image_modulation=z_image_modulation,
attn_out_bias=False,
operation_settings=operation_settings,
)
for layer_id in range(n_layers)
]
)
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
if self.pad_tokens_multiple is not None:
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
self.dim = dim
self.n_heads = n_heads
@ -503,108 +506,42 @@ class NextDiT(nn.Module):
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
dtype = x[0].dtype
if cap_mask is not None:
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
else:
l_effective_cap_len = [num_tokens] * bsz
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
if cap_mask is not None and not torch.is_floating_point(cap_mask):
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
img_sizes = [(img.size(1), img.size(2)) for img in x]
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
B, C, H, W = x.shape
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
max_seq_len = max(
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
)
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
H_tokens, W_tokens = H // pH, W // pW
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
# build freqs_cis for cap and image individually
cap_freqs_cis_shape = list(freqs_cis.shape)
# cap_freqs_cis_shape[1] = max_cap_len
cap_freqs_cis_shape[1] = cap_feats.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
# refine image
flat_x = []
for i in range(bsz):
img = x[i]
C, H, W = img.size()
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_x.append(img)
x = flat_x
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
for i in range(bsz):
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
padded_img_mask = None
for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
else:
mask = None
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
@ -627,7 +564,7 @@ class NextDiT(nn.Module):
y: (N,) tensor of text tokens/features
"""
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute

View File

@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
super().__init__()
if output_size is None:
output_size = hidden_size
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size

View File

@ -926,7 +926,7 @@ class Flux(BaseModel):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class Flux2(Flux):
@ -1114,9 +1114,13 @@ class Lumina2(BaseModel):
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
return out
class WAN21(BaseModel):

View File

@ -416,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
dit_config["dim"] = w.shape[0]
dit_config["cap_feat_dim"] = w.shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512]
if dit_config["dim"] == 2304: # Original Lumina 2
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
dit_config["axes_dims"] = [32, 48, 48]
dit_config["axes_lens"] = [1536, 512, 512]
dit_config["rope_theta"] = 256.0
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
dit_config["z_image_modulation"] = True
dit_config["time_scale"] = 1000.0
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1

View File

@ -132,7 +132,7 @@ class LowVramPatch:
def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
weight = self.convert_func(weight, inplace=False)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32

View File

@ -117,6 +117,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if weight_has_function or weight.dtype != dtype:
with wf_context:
weight = weight.to(dtype=dtype)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
for f in s.weight_function:
weight = f(weight)
@ -502,7 +504,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
@ -643,6 +645,24 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if isinstance(weight, QuantizedTensor):
return weight.dequantize()
else:
return weight
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
return weight
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):

View File

@ -1,6 +1,7 @@
import torch
import logging
from typing import Tuple, Dict
import comfy.float
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
@ -393,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
orig_dtype = tensor.dtype
if scale is None:
@ -403,17 +404,23 @@ class TensorCoreFP8Layout(QuantizedLayout):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)
else:
tensor = tensor * (1.0 / scale).to(tensor.dtype)
if stochastic_rounding > 0:
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
else:
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return qdata, layout_params
return tensor, layout_params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):

View File

@ -52,6 +52,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.model_patcher
import comfy.lora
@ -953,6 +954,8 @@ class TEModel(Enum):
GEMMA_3_4B = 13
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -985,6 +988,8 @@ def detect_te_model(sd):
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight']
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
@ -1110,6 +1115,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:

View File

@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
@ -164,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if self.layer == "all":
if isinstance(self.layer, list) or self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
@ -266,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if self.enable_attention_masks:
attention_mask_model = attention_mask
if self.layer == "all":
if isinstance(self.layer, list):
intermediate_output = self.layer
elif self.layer == "all":
intermediate_output = "all"
else:
intermediate_output = self.layer_idx

View File

@ -21,6 +21,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
from . import supported_models_base
from . import latent_formats
@ -994,7 +995,7 @@ class Lumina2(supported_models_base.BASE):
"shift": 6.0,
}
memory_usage_factor = 1.2
memory_usage_factor = 1.4
unet_extra_config = {}
latent_format = latent_formats.Flux
@ -1013,6 +1014,24 @@ class Lumina2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
class ZImage(Lumina2):
unet_config = {
"image_model": "lumina2",
"dim": 3840,
}
sampling_settings = {
"multiplier": 1.0,
"shift": 3.0,
}
memory_usage_factor = 1.7
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@ -1453,7 +1472,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
models += [SVD_img2vid]

View File

@ -87,6 +87,7 @@ def load_mistral_tokenizer(data):
vocab = {}
max_vocab = mistral_vocab["config"]["default_vocab_size"]
max_vocab -= len(mistral_vocab["special_tokens"])
for w in mistral_vocab["vocab"]:
r = w["rank"]
@ -137,7 +138,7 @@ class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
return tokens
class Mistral3_24BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = {}
num_layers = model_options.get("num_layers", None)
if num_layers is not None:
@ -153,7 +154,7 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1)
out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
out = out.movedim(1, 2)
out = out.reshape(out.shape[0], out.shape[1], -1)
return out, pooled, extra

View File

@ -78,6 +78,28 @@ class Qwen25_3BConfig:
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
hidden_size: int = 2560
intermediate_size: int = 9728
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen25_7BVLI_Config:
vocab_size: int = 152064
@ -434,8 +456,12 @@ class Llama2_(nn.Module):
intermediate = None
all_intermediate = None
only_layers = None
if intermediate_output is not None:
if intermediate_output == "all":
if isinstance(intermediate_output, list):
all_intermediate = []
only_layers = set(intermediate_output)
elif intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
@ -443,7 +469,8 @@ class Llama2_(nn.Module):
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if only_layers is None or (i in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
x = layer(
x=x,
attention_mask=mask,
@ -457,7 +484,8 @@ class Llama2_(nn.Module):
x = self.norm(x)
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if only_layers is None or ((i + 1) in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
@ -505,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@ -0,0 +1,48 @@
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class ZImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Qwen3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ZImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
class ZImageTEModel_(ZImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return ZImageTEModel_

View File

@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase):
lora_diff = torch.mm(
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
).reshape(weight.shape)
del mat1, mat2
if dora_scale is not None:
weight = weight_decompose(
dora_scale,

View File

@ -58,8 +58,14 @@ class GeminiInlineData(BaseModel):
mimeType: GeminiMimeType | None = Field(None)
class GeminiFileData(BaseModel):
fileUri: str | None = Field(None)
mimeType: GeminiMimeType | None = Field(None)
class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None)

View File

@ -1,34 +1,21 @@
from typing import Optional, Union
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class Image2(BaseModel):
bytesBase64Encoded: str
gcsUri: Optional[str] = None
mimeType: Optional[str] = None
class VeoRequestInstanceImage(BaseModel):
bytesBase64Encoded: str | None = Field(None)
gcsUri: str | None = Field(None)
mimeType: str | None = Field(None)
class Image3(BaseModel):
bytesBase64Encoded: Optional[str] = None
gcsUri: str
mimeType: Optional[str] = None
class Instance1(BaseModel):
image: Optional[Union[Image2, Image3]] = Field(
None, description='Optional image to guide video generation'
)
class VeoRequestInstance(BaseModel):
image: VeoRequestInstanceImage | None = Field(None)
lastFrame: VeoRequestInstanceImage | None = Field(None)
prompt: str = Field(..., description='Text description of the video')
class PersonGeneration1(str, Enum):
ALLOW = 'ALLOW'
BLOCK = 'BLOCK'
class Parameters1(BaseModel):
class VeoRequestParameters(BaseModel):
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
durationSeconds: Optional[int] = None
enhancePrompt: Optional[bool] = None
@ -37,17 +24,18 @@ class Parameters1(BaseModel):
description='Generate audio for the video. Only supported by veo 3 models.',
)
negativePrompt: Optional[str] = None
personGeneration: Optional[PersonGeneration1] = None
personGeneration: str | None = Field(None, description="ALLOW or BLOCK")
sampleCount: Optional[int] = None
seed: Optional[int] = None
storageUri: Optional[str] = Field(
None, description='Optional Cloud Storage URI to upload the video'
)
resolution: str | None = Field(None)
class VeoGenVidRequest(BaseModel):
instances: Optional[list[Instance1]] = None
parameters: Optional[Parameters1] = None
instances: list[VeoRequestInstance] | None = Field(None)
parameters: VeoRequestParameters | None = Field(None)
class VeoGenVidResponse(BaseModel):

View File

@ -4,10 +4,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
"""
import base64
import json
import os
import time
import uuid
from enum import Enum
from io import BytesIO
from typing import Literal
@ -20,6 +17,7 @@ from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis.gemini_api import (
GeminiContent,
GeminiFileData,
GeminiGenerateContentRequest,
GeminiGenerateContentResponse,
GeminiImageConfig,
@ -38,10 +36,10 @@ from comfy_api_nodes.util import (
get_number_of_images,
sync_op,
tensor_to_base64_string,
upload_images_to_comfyapi,
validate_string,
video_to_base64_string,
)
from server import PromptServer
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
@ -68,24 +66,43 @@ class GeminiImageModel(str, Enum):
gemini_2_5_flash_image = "gemini-2.5-flash-image"
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.
Args:
image_input: Batch of image tensors from ComfyUI.
Returns:
List of GeminiPart objects containing the encoded images.
"""
async def create_image_parts(
cls: type[IO.ComfyNode],
images: torch.Tensor,
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
if image_limit < 0:
raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
total_images = get_number_of_images(images)
if total_images <= 0:
raise ValueError("No images provided to create_image_parts; at least one image is required.")
# If image_limit == 0 --> use all images; otherwise clamp to image_limit.
effective_max = total_images if image_limit == 0 else min(total_images, image_limit)
# Number of images we'll send as URLs (fileData)
num_url_images = min(effective_max, 10) # Vertex API max number of image links
reference_images_urls = await upload_images_to_comfyapi(
cls,
images,
max_images=num_url_images,
)
for reference_image_url in reference_images_urls:
image_parts.append(
GeminiPart(
fileData=GeminiFileData(
mimeType=GeminiMimeType.image_png,
fileUri=reference_image_url,
)
)
)
for idx in range(num_url_images, effective_max):
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=image_as_b64,
data=tensor_to_base64_string(images[idx]),
)
)
)
@ -338,8 +355,7 @@ class GeminiNode(IO.ComfyNode):
# Add other modal parts
if images is not None:
image_parts = create_image_parts(images)
parts.extend(image_parts)
parts.extend(await create_image_parts(cls, images))
if audio is not None:
parts.extend(cls.create_audio_parts(audio))
if video is not None:
@ -364,29 +380,6 @@ class GeminiNode(IO.ComfyNode):
)
output_text = get_text_from_response(response)
if output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
@ -562,8 +555,7 @@ class GeminiImage(IO.ComfyNode):
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
if images is not None:
image_parts = create_image_parts(images)
parts.extend(image_parts)
parts.extend(await create_image_parts(cls, images))
if files is not None:
parts.extend(files)
@ -582,30 +574,7 @@ class GeminiImage(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
output_text = get_text_from_response(response)
if output_text:
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(get_image_from_response(response), output_text)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
class GeminiImage2(IO.ComfyNode):
@ -702,7 +671,7 @@ class GeminiImage2(IO.ComfyNode):
if images is not None:
if get_number_of_images(images) > 14:
raise ValueError("The current maximum number of supported images is 14.")
parts.extend(create_image_parts(images))
parts.extend(await create_image_parts(cls, images))
if files is not None:
parts.extend(files)
@ -725,30 +694,7 @@ class GeminiImage2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
output_text = get_text_from_response(response)
if output_text:
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(get_image_from_response(response), output_text)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension):

View File

@ -1,15 +1,10 @@
from io import BytesIO
from typing import Optional, Union
import json
import os
import time
import uuid
from enum import Enum
from inspect import cleandoc
import numpy as np
import torch
from PIL import Image
from server import PromptServer
import folder_paths
import base64
from comfy_api.latest import IO, ComfyExtension
@ -587,11 +582,11 @@ class OpenAIChatNode(IO.ComfyNode):
def create_input_message_contents(
cls,
prompt: str,
image: Optional[torch.Tensor] = None,
files: Optional[list[InputFileContent]] = None,
image: torch.Tensor | None = None,
files: list[InputFileContent] | None = None,
) -> InputMessageContentList:
"""Create a list of input message contents from prompt and optional image."""
content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [
content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [
InputTextContent(text=prompt, type="input_text"),
]
if image is not None:
@ -617,9 +612,9 @@ class OpenAIChatNode(IO.ComfyNode):
prompt: str,
persist_context: bool = False,
model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value,
images: Optional[torch.Tensor] = None,
files: Optional[list[InputFileContent]] = None,
advanced_options: Optional[CreateModelResponseProperties] = None,
images: torch.Tensor | None = None,
files: list[InputFileContent] | None = None,
advanced_options: CreateModelResponseProperties | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@ -660,30 +655,7 @@ class OpenAIChatNode(IO.ComfyNode):
status_extractor=lambda response: response.status,
completed_statuses=["incomplete", "completed"]
)
output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))
# Update history
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(output_text)
return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)))
class OpenAIInputFiles(IO.ComfyNode):
@ -790,8 +762,8 @@ class OpenAIChatConfig(IO.ComfyNode):
def execute(
cls,
truncation: bool,
instructions: Optional[str] = None,
max_output_tokens: Optional[int] = None,
instructions: str | None = None,
max_output_tokens: int | None = None,
) -> IO.NodeOutput:
"""
Configure advanced options for the OpenAI Chat Node.

View File

@ -1,6 +1,7 @@
import base64
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.input_impl.video_types import VideoFromFile
@ -10,6 +11,9 @@ from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollResponse,
VeoGenVidRequest,
VeoGenVidResponse,
VeoRequestInstance,
VeoRequestInstanceImage,
VeoRequestParameters,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -346,12 +350,163 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
)
class Veo3FirstLastFrameNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Veo3FirstLastFrameNode",
display_name="Google Veo 3 First-Last-Frame to Video",
category="api node/video/Veo",
description="Generate video using prompt and first and last frames.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the video",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid in the video",
),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16"],
default="16:9",
tooltip="Aspect ratio of the output video",
),
IO.Int.Input(
"duration",
default=8,
min=4,
max=8,
step=2,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFF,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation",
),
IO.Image.Input("first_frame", tooltip="Start frame"),
IO.Image.Input("last_frame", tooltip="End frame"),
IO.Combo.Input(
"model",
options=["veo-3.1-generate", "veo-3.1-fast-generate"],
default="veo-3.1-fast-generate",
),
IO.Boolean.Input(
"generate_audio",
default=True,
tooltip="Generate audio for the video.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
resolution: str,
aspect_ratio: str,
duration: int,
seed: int,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
model: str,
generate_audio: bool,
):
model = MODELS_MAP[model]
initial_response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
response_model=VeoGenVidResponse,
data=VeoGenVidRequest(
instances=[
VeoRequestInstance(
prompt=prompt,
image=VeoRequestInstanceImage(
bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png"
),
lastFrame=VeoRequestInstanceImage(
bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png"
),
),
],
parameters=VeoRequestParameters(
aspectRatio=aspect_ratio,
personGeneration="ALLOW",
durationSeconds=duration,
enhancePrompt=True, # cannot be False for Veo3
seed=seed,
generateAudio=generate_audio,
negativePrompt=negative_prompt,
resolution=resolution,
),
),
)
poll_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
response_model=VeoGenVidPollResponse,
status_extractor=lambda r: "completed" if r.done else "pending",
data=VeoGenVidPollRequest(
operationName=initial_response.name,
),
poll_interval=5.0,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
if poll_response.error:
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
response = poll_response.response
filtered_count = response.raiMediaFilteredCount
if filtered_count:
reasons = response.raiMediaFilteredReasons or []
reason_part = f": {reasons[0]}" if reasons else ""
raise Exception(
f"Content blocked by Google's Responsible AI filters{reason_part} "
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
)
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")
raise Exception("Video generation completed but no video was returned")
class VeoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
VeoVideoGenerationNode,
Veo3VideoGenerationNode,
Veo3FirstLastFrameNode,
]

View File

@ -4,7 +4,7 @@ import logging
import time
import uuid
from io import BytesIO
from typing import Optional, Union
from typing import Optional
from urllib.parse import urlparse
import aiohttp
@ -48,8 +48,9 @@ async def upload_images_to_comfyapi(
image: torch.Tensor,
*,
max_images: int = 8,
mime_type: Optional[str] = None,
wait_label: Optional[str] = "Uploading",
mime_type: str | None = None,
wait_label: str | None = "Uploading",
show_batch_index: bool = True,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
@ -59,11 +60,18 @@ async def upload_images_to_comfyapi(
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
num_to_upload = min(batch_len, max_images)
batch_start_ts = time.monotonic()
for idx in range(min(batch_len, max_images)):
for idx in range(num_to_upload):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
effective_label = wait_label
if wait_label and show_batch_index and num_to_upload > 1:
effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
download_urls.append(url)
return download_urls
@ -126,8 +134,9 @@ async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode],
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
wait_label: Optional[str] = "Uploading",
upload_mime_type: str | None,
wait_label: str | None = "Uploading",
progress_origin_ts: float | None = None,
) -> str:
"""Uploads a single file to ComfyUI API and returns its download URL."""
if upload_mime_type is None:
@ -148,6 +157,7 @@ async def upload_file_to_comfyapi(
file_bytes_io,
content_type=upload_mime_type,
wait_label=wait_label,
progress_origin_ts=progress_origin_ts,
)
return create_resp.download_url
@ -155,27 +165,18 @@ async def upload_file_to_comfyapi(
async def upload_file(
cls: type[IO.ComfyNode],
upload_url: str,
file: Union[BytesIO, str],
file: BytesIO | str,
*,
content_type: Optional[str] = None,
content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: Optional[str] = None,
wait_label: str | None = None,
progress_origin_ts: float | None = None,
) -> None:
"""
Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
Args:
cls: Node class (provides auth context + UI progress hooks).
upload_url: Pre-signed PUT URL.
file: BytesIO or path string.
content_type: Explicit MIME type. If None, we *suppress* Content-Type.
max_retries: Maximum retry attempts.
retry_delay: Initial delay in seconds.
retry_backoff: Exponential backoff factor.
wait_label: Progress label shown in Comfy UI.
Raises:
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
"""
@ -198,7 +199,7 @@ async def upload_file(
attempt = 0
delay = retry_delay
start_ts = time.monotonic()
start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
op_uuid = uuid.uuid4().hex[:8]
while True:
attempt += 1

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -7,6 +7,10 @@ from comfy_api.input_impl import VideoFromFile
from pathlib import Path
from PIL import Image
import numpy as np
import uuid
def normalize_path(path):
return path.replace('\\', '/')
@ -34,58 +38,6 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
class Load3DAnimation():
@classmethod
def INPUT_TYPES(s):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),
"image": ("LOAD_3D_ANIMATION", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
@ -120,7 +72,8 @@ class Preview3D():
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
"camera_info": ("LOAD3D_CAMERA", {}),
"bg_image": ("IMAGE", {})
}}
OUTPUT_NODE = True
@ -133,50 +86,33 @@ class Preview3D():
def process(self, model_file, **kwargs):
camera_info = kwargs.get("camera_info", None)
bg_image = kwargs.get("bg_image", None)
bg_image_path = None
if bg_image is not None:
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
img = Image.fromarray(img_array)
temp_dir = folder_paths.get_temp_directory()
filename = f"bg_{uuid.uuid4().hex}.png"
bg_image_path = os.path.join(temp_dir, filename)
img.save(bg_image_path, compress_level=1)
bg_image_path = f"temp/{filename}"
return {
"ui": {
"result": [model_file, camera_info]
}
}
class Preview3DAnimation():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
OUTPUT_NODE = True
RETURN_TYPES = ()
CATEGORY = "3d"
FUNCTION = "process"
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
camera_info = kwargs.get("camera_info", None)
return {
"ui": {
"result": [model_file, camera_info]
"result": [model_file, camera_info, bg_image_path]
}
}
NODE_CLASS_MAPPINGS = {
"Load3D": Load3D,
"Load3DAnimation": Load3DAnimation,
"Preview3D": Preview3D,
"Preview3DAnimation": Preview3DAnimation
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Load3D": "Load 3D",
"Load3DAnimation": "Load 3D - Animation",
"Preview3D": "Preview 3D",
"Preview3DAnimation": "Preview 3D - Animation"
"Load3D": "Load 3D & Animation",
"Preview3D": "Preview 3D & Animation",
}

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.71"
__version__ = "0.3.75"

View File

@ -37,13 +37,16 @@ class TAESDPreviewerImpl(LatentPreviewer):
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None
if latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
self.latent_rgb_factors_reshape = latent_rgb_factors_reshape
def decode_latent_to_preview(self, x0):
if self.latent_rgb_factors_reshape is not None:
x0 = self.latent_rgb_factors_reshape(x0)
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
@ -85,7 +88,7 @@ def get_previewer(device, latent_format):
if previewer is None:
if latent_format.latent_rgb_factors is not None:
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape)
return previewer
def prepare_callback(model, steps, x0_output_dict=None):

View File

@ -2278,6 +2278,7 @@ async def init_builtin_extra_nodes():
"nodes_images.py",
"nodes_video_model.py",
"nodes_train.py",
"nodes_dataset.py",
"nodes_sag.py",
"nodes_perpneg.py",
"nodes_stable3d.py",

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.71"
version = "0.3.75"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.30.6
comfyui-workflow-templates==0.7.9
comfyui-frontend-package==1.32.9
comfyui-workflow-templates==0.7.20
comfyui-embedded-docs==0.3.1
torch
torchsde

View File

@ -174,7 +174,7 @@ def create_block_external_middleware():
else:
response = await handler(request)
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
return response
return block_external_middleware