Merge branch 'comfyanonymous:master' into offloader-maifee
commit cee75f301a
@@ -67,6 +67,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
    - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
    - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
    - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
+   - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
+   - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
 - Image Editing Models
    - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
    - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@@ -6,6 +6,7 @@ class LatentFormat:
     latent_dimensions = 2
     latent_rgb_factors = None
     latent_rgb_factors_bias = None
+    latent_rgb_factors_reshape = None
     taesd_decoder_name = None

     def process_in(self, latent):
@@ -181,6 +182,45 @@ class Flux(SD3):
 class Flux2(LatentFormat):
     latent_channels = 128

+    def __init__(self):
+        self.latent_rgb_factors =[
+            [0.0058, 0.0113, 0.0073],
+            [0.0495, 0.0443, 0.0836],
+            [-0.0099, 0.0096, 0.0644],
+            [0.2144, 0.3009, 0.3652],
+            [0.0166, -0.0039, -0.0054],
+            [0.0157, 0.0103, -0.0160],
+            [-0.0398, 0.0902, -0.0235],
+            [-0.0052, 0.0095, 0.0109],
+            [-0.3527, -0.2712, -0.1666],
+            [-0.0301, -0.0356, -0.0180],
+            [-0.0107, 0.0078, 0.0013],
+            [0.0746, 0.0090, -0.0941],
+            [0.0156, 0.0169, 0.0070],
+            [-0.0034, -0.0040, -0.0114],
+            [0.0032, 0.0181, 0.0080],
+            [-0.0939, -0.0008, 0.0186],
+            [0.0018, 0.0043, 0.0104],
+            [0.0284, 0.0056, -0.0127],
+            [-0.0024, -0.0022, -0.0030],
+            [0.1207, -0.0026, 0.0065],
+            [0.0128, 0.0101, 0.0142],
+            [0.0137, -0.0072, -0.0007],
+            [0.0095, 0.0092, -0.0059],
+            [0.0000, -0.0077, -0.0049],
+            [-0.0465, -0.0204, -0.0312],
+            [0.0095, 0.0012, -0.0066],
+            [0.0290, -0.0034, 0.0025],
+            [0.0220, 0.0169, -0.0048],
+            [-0.0332, -0.0457, -0.0468],
+            [-0.0085, 0.0389, 0.0609],
+            [-0.0076, 0.0003, -0.0043],
+            [-0.0111, -0.0460, -0.0614],
+            ]
+
+        self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
+        self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
+
     def process_in(self, latent):
         return latent
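
Editor's sketch (not part of the commit): Flux 2 packs each 2x2 block of spatial positions into the channel dimension (32 * 2 * 2 = 128), so the preview path first unpacks the latent back to 32 channels at twice the resolution before the 32 RGB factor rows are applied. The factor and bias tensors below are random stand-ins for the tables above.

```python
import torch

# same expression as latent_rgb_factors_reshape above
reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)

latent = torch.randn(1, 128, 16, 16)                 # packed Flux 2 latent
unpacked = reshape(latent)                           # -> (1, 32, 32, 32)
factors = torch.randn(32, 3)                         # stand-in for latent_rgb_factors
bias = torch.randn(3)                                # stand-in for latent_rgb_factors_bias
preview = torch.einsum("bchw,cr->brhw", unpacked, factors) + bias.view(1, 3, 1, 1)
print(unpacked.shape, preview.shape)                 # (1, 32, 32, 32) (1, 3, 32, 32)
```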
@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope
 import comfy.patcher_extension


@@ -31,6 +32,7 @@ class JointAttention(nn.Module):
         n_heads: int,
         n_kv_heads: Optional[int],
         qk_norm: bool,
+        out_bias: bool = False,
         operation_settings={},
     ):
         """
@@ -59,7 +61,7 @@ class JointAttention(nn.Module):
         self.out = operation_settings.get("operations").Linear(
             n_heads * self.head_dim,
             dim,
-            bias=False,
+            bias=out_bias,
             device=operation_settings.get("device"),
             dtype=operation_settings.get("dtype"),
         )
@@ -70,35 +72,6 @@ class JointAttention(nn.Module):
         else:
             self.q_norm = self.k_norm = nn.Identity()

-    @staticmethod
-    def apply_rotary_emb(
-        x_in: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Apply rotary embeddings to input tensors using the given frequency
-        tensor.
-
-        This function applies rotary embeddings to the given query 'xq' and
-        key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
-        input tensors are reshaped as complex numbers, and the frequency tensor
-        is reshaped for broadcasting compatibility. The resulting tensors
-        contain rotary embeddings and are returned as real tensors.
-
-        Args:
-            x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
-            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
-                exponentials.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
-                and key tensor with rotary embeddings.
-        """
-
-        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
-        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
-        return t_out.reshape(*x_in.shape)
-
     def forward(
         self,
         x: torch.Tensor,
@@ -134,8 +107,7 @@ class JointAttention(nn.Module):
         xq = self.q_norm(xq)
         xk = self.k_norm(xk)

-        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
-        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
+        xq, xk = apply_rope(xq, xk, freqs_cis)

         n_rep = self.n_local_heads // self.n_local_kv_heads
         if n_rep >= 1:
@@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module):
         norm_eps: float,
         qk_norm: bool,
         modulation=True,
+        z_image_modulation=False,
+        attn_out_bias=False,
         operation_settings={},
     ) -> None:
         """
@@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module):
         super().__init__()
         self.dim = dim
         self.head_dim = dim // n_heads
-        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
+        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
         self.feed_forward = FeedForward(
             dim=dim,
-            hidden_dim=4 * dim,
+            hidden_dim=dim,
             multiple_of=multiple_of,
             ffn_dim_multiplier=ffn_dim_multiplier,
             operation_settings=operation_settings,
@@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module):

         self.modulation = modulation
         if modulation:
-            self.adaLN_modulation = nn.Sequential(
-                nn.SiLU(),
-                operation_settings.get("operations").Linear(
-                    min(dim, 1024),
-                    4 * dim,
-                    bias=True,
-                    device=operation_settings.get("device"),
-                    dtype=operation_settings.get("dtype"),
-                ),
-            )
+            if z_image_modulation:
+                self.adaLN_modulation = nn.Sequential(
+                    operation_settings.get("operations").Linear(
+                        min(dim, 256),
+                        4 * dim,
+                        bias=True,
+                        device=operation_settings.get("device"),
+                        dtype=operation_settings.get("dtype"),
+                    ),
+                )
+            else:
+                self.adaLN_modulation = nn.Sequential(
+                    nn.SiLU(),
+                    operation_settings.get("operations").Linear(
+                        min(dim, 1024),
+                        4 * dim,
+                        bias=True,
+                        device=operation_settings.get("device"),
+                        dtype=operation_settings.get("dtype"),
+                    ),
+                )

     def forward(
         self,
@@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
     The final layer of NextDiT.
     """

-    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
+    def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
         super().__init__()
         self.norm_final = operation_settings.get("operations").LayerNorm(
             hidden_size,
@@ -340,10 +325,15 @@ class FinalLayer(nn.Module):
             dtype=operation_settings.get("dtype"),
         )

+        if z_image_modulation:
+            min_mod = 256
+        else:
+            min_mod = 1024
+
         self.adaLN_modulation = nn.Sequential(
             nn.SiLU(),
             operation_settings.get("operations").Linear(
-                min(hidden_size, 1024),
+                min(hidden_size, min_mod),
                 hidden_size,
                 bias=True,
                 device=operation_settings.get("device"),
@@ -373,12 +363,16 @@ class NextDiT(nn.Module):
         n_heads: int = 32,
         n_kv_heads: Optional[int] = None,
         multiple_of: int = 256,
-        ffn_dim_multiplier: Optional[float] = None,
+        ffn_dim_multiplier: float = 4.0,
         norm_eps: float = 1e-5,
         qk_norm: bool = False,
         cap_feat_dim: int = 5120,
         axes_dims: List[int] = (16, 56, 56),
         axes_lens: List[int] = (1, 512, 512),
+        rope_theta=10000.0,
+        z_image_modulation=False,
+        time_scale=1.0,
+        pad_tokens_multiple=None,
         image_model=None,
         device=None,
         dtype=None,
@@ -390,6 +384,8 @@ class NextDiT(nn.Module):
         self.in_channels = in_channels
         self.out_channels = in_channels
         self.patch_size = patch_size
+        self.time_scale = time_scale
+        self.pad_tokens_multiple = pad_tokens_multiple

         self.x_embedder = operation_settings.get("operations").Linear(
             in_features=patch_size * patch_size * in_channels,
@@ -411,6 +407,7 @@ class NextDiT(nn.Module):
                     norm_eps,
                     qk_norm,
                     modulation=True,
+                    z_image_modulation=z_image_modulation,
                     operation_settings=operation_settings,
                 )
                 for layer_id in range(n_refiner_layers)
@@ -434,7 +431,7 @@ class NextDiT(nn.Module):
             ]
         )

-        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
+        self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
         self.cap_embedder = nn.Sequential(
             operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
@@ -457,18 +454,24 @@ class NextDiT(nn.Module):
                     ffn_dim_multiplier,
                     norm_eps,
                     qk_norm,
+                    z_image_modulation=z_image_modulation,
+                    attn_out_bias=False,
                     operation_settings=operation_settings,
                 )
                 for layer_id in range(n_layers)
             ]
         )
         self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
+        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
+
+        if self.pad_tokens_multiple is not None:
+            self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
+            self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))

         assert (dim // n_heads) == sum(axes_dims)
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
-        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
+        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
         self.dim = dim
         self.n_heads = n_heads

@@ -503,108 +506,42 @@ class NextDiT(nn.Module):
         bsz = len(x)
         pH = pW = self.patch_size
         device = x[0].device
-        dtype = x[0].dtype
-
-        if cap_mask is not None:
-            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
-        else:
-            l_effective_cap_len = [num_tokens] * bsz
-
-        if cap_mask is not None and not torch.is_floating_point(cap_mask):
-            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
-
-        img_sizes = [(img.size(1), img.size(2)) for img in x]
-        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
-
-        max_seq_len = max(
-            (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
-        )
-        max_cap_len = max(l_effective_cap_len)
-        max_img_len = max(l_effective_img_len)
-
-        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
-
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            H, W = img_sizes[i]
-            H_tokens, W_tokens = H // pH, W // pW
-            assert H_tokens * W_tokens == img_len
-
-            rope_options = transformer_options.get("rope_options", None)
-            h_scale = 1.0
-            w_scale = 1.0
-            h_start = 0
-            w_start = 0
-            if rope_options is not None:
-                h_scale = rope_options.get("scale_y", 1.0)
-                w_scale = rope_options.get("scale_x", 1.0)
-
-                h_start = rope_options.get("shift_y", 0.0)
-                w_start = rope_options.get("shift_x", 0.0)
-
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
-            position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
-            row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
-            col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
-            position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
-            position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
-
-        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
-
-        # build freqs_cis for cap and image individually
-        cap_freqs_cis_shape = list(freqs_cis.shape)
-        # cap_freqs_cis_shape[1] = max_cap_len
-        cap_freqs_cis_shape[1] = cap_feats.shape[1]
-        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        img_freqs_cis_shape = list(freqs_cis.shape)
-        img_freqs_cis_shape[1] = max_img_len
-        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
+
+        if self.pad_tokens_multiple is not None:
+            pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
+            cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
+
+        cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
+        cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
+
+        B, C, H, W = x.shape
+        x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+
+        H_tokens, W_tokens = H // pH, W // pW
+        x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
+        x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
+        x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+        x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+
+        if self.pad_tokens_multiple is not None:
+            pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
+            x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
+            x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
+
+        freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)

         # refine context
         for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
+            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)

-        # refine image
-        flat_x = []
-        for i in range(bsz):
-            img = x[i]
-            C, H, W = img.size()
-            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
-            flat_x.append(img)
-        x = flat_x
-        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
-        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
-        for i in range(bsz):
-            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
-            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
-
-        padded_img_embed = self.x_embedder(padded_img_embed)
-        padded_img_mask = padded_img_mask.unsqueeze(1)
+        padded_img_mask = None
         for layer in self.noise_refiner:
-            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
+            x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)

-        if cap_mask is not None:
-            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
-            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
-        else:
-            mask = None
-
-        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-
-            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
-            padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
-
+        padded_full_embed = torch.cat((cap_feats, x), dim=1)
+        mask = None
+        img_sizes = [(H, W)] * bsz
+        l_effective_cap_len = [cap_feats.shape[1]] * bsz
         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

     def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
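
Editor's sketch (not part of the commit): the per-sample loop with zero-padding is replaced by one batched patchify. The toy-sized reshape below shows the token layout that single line produces; the Lumina 2 / Z Image configs in this commit use patch_size = 2.

```python
import torch

B, C, H, W, p = 2, 16, 8, 8, 2  # toy sizes
x = torch.randn(B, C, H, W)
tokens = x.view(B, C, H // p, p, W // p, p).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
print(tokens.shape)  # torch.Size([2, 16, 64]) == (B, H//p * W//p, p * p * C)
```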
@@ -627,7 +564,7 @@ class NextDiT(nn.Module):
         y: (N,) tensor of text tokens/features
         """

-        t = self.t_embedder(t, dtype=x.dtype) # (N, D)
+        t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
         adaln_input = t

         cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
@@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
     Embeds scalar timesteps into vector representations.
     """

-    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
         super().__init__()
+        if output_size is None:
+            output_size = hidden_size
         self.mlp = nn.Sequential(
             operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
             nn.SiLU(),
-            operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
+            operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
         )
         self.frequency_embedding_size = frequency_embedding_size
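
Editor's sketch (not part of the commit): with the new output_size argument the final Linear of the MLP can project down to the 256-wide modulation vector that the Z Image blocks expect, while the default behaviour is unchanged. A rough instantiation, assuming plain torch.nn can stand in for the operations object:

```python
import torch
import torch.nn as nn
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder

emb = TimestepEmbedder(1024, output_size=256, dtype=torch.float32, device="cpu", operations=nn)
out = emb(torch.tensor([0.5]), dtype=torch.float32)
print(out.shape)  # expected (1, 256); (1, 1024) when output_size is left as None
```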
@@ -926,7 +926,7 @@ class Flux(BaseModel):
         out = {}
         ref_latents = kwargs.get("reference_latents", None)
         if ref_latents is not None:
-            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
         return out

 class Flux2(Flux):
@@ -1114,9 +1114,13 @@ class Lumina2(BaseModel):
             if torch.numel(attention_mask) != attention_mask.sum():
                 out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
             out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
+
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+            if 'num_tokens' not in out:
+                out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
+
         return out

 class WAN21(BaseModel):
@@ -416,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["image_model"] = "lumina2"
         dit_config["patch_size"] = 2
         dit_config["in_channels"] = 16
-        dit_config["dim"] = 2304
-        dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
+        w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
+        dit_config["dim"] = w.shape[0]
+        dit_config["cap_feat_dim"] = w.shape[1]
         dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
-        dit_config["n_heads"] = 24
-        dit_config["n_kv_heads"] = 8
         dit_config["qk_norm"] = True
-        dit_config["axes_dims"] = [32, 32, 32]
-        dit_config["axes_lens"] = [300, 512, 512]
+        if dit_config["dim"] == 2304: # Original Lumina 2
+            dit_config["n_heads"] = 24
+            dit_config["n_kv_heads"] = 8
+            dit_config["axes_dims"] = [32, 32, 32]
+            dit_config["axes_lens"] = [300, 512, 512]
+            dit_config["rope_theta"] = 10000.0
+            dit_config["ffn_dim_multiplier"] = 4.0
+        elif dit_config["dim"] == 3840: # Z image
+            dit_config["n_heads"] = 30
+            dit_config["n_kv_heads"] = 30
+            dit_config["axes_dims"] = [32, 48, 48]
+            dit_config["axes_lens"] = [1536, 512, 512]
+            dit_config["rope_theta"] = 256.0
+            dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
+            dit_config["z_image_modulation"] = True
+            dit_config["time_scale"] = 1000.0
+            if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
+                dit_config["pad_tokens_multiple"] = 32
+
         return dit_config

     if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
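
Editor's sketch (not part of the commit): the detection now reads the model width off the cap_embedder Linear weight, whose shape is (dim, cap_feat_dim), so a 3840-wide checkpoint is routed to the Z Image settings and a 2304-wide one to the original Lumina 2 settings. The concrete shapes below are hypothetical and only illustrate the lookup:

```python
import torch

state_dict = {"cap_embedder.1.weight": torch.zeros(3840, 2560)}  # hypothetical Z Image shapes
w = state_dict["cap_embedder.1.weight"]
dim, cap_feat_dim = w.shape
print(dim, cap_feat_dim, "Z Image" if dim == 3840 else "Lumina 2")  # 3840 2560 Z Image
```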
@@ -132,7 +132,7 @@ class LowVramPatch:
     def __call__(self, weight):
         intermediate_dtype = weight.dtype
         if self.convert_func is not None:
-            weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
+            weight = self.convert_func(weight, inplace=False)

         if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
             intermediate_dtype = torch.float32
comfy/ops.py
@@ -117,6 +117,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if weight_has_function or weight.dtype != dtype:
         with wf_context:
             weight = weight.to(dtype=dtype)
+            if isinstance(weight, QuantizedTensor):
+                weight = weight.dequantize()
             for f in s.weight_function:
                 weight = f(weight)

@@ -502,7 +504,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
                 weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                 return weight
             else:
-                return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+                return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)

         def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
             weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
@@ -643,6 +645,24 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
                     not isinstance(input, QuantizedTensor)):
                 input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
             return self._forward(input, self.weight, self.bias)
+
+        def convert_weight(self, weight, inplace=False, **kwargs):
+            if isinstance(weight, QuantizedTensor):
+                return weight.dequantize()
+            else:
+                return weight
+
+        def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
+            if getattr(self, 'layout_type', None) is not None:
+                weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+            else:
+                weight = weight.to(self.weight.dtype)
+            if return_weight:
+                return weight
+
+            assert inplace_update is False # TODO: eventually remove the inplace_update stuff
+            self.weight = torch.nn.Parameter(weight, requires_grad=False)
+
     return MixedPrecisionOps

def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
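
Editor's sketch (not part of the commit): together with the LowVramPatch change above, the new convert_weight/set_weight pair is what lets weight patches run on a dequantized floating-point copy and then be stored back in quantized form. A rough outline of that round trip, with the patch function left abstract:

```python
def patch_quantized_layer(layer, patch_fn, seed=0):
    # dequantize (a no-op for plain tensors), patch in float, then re-quantize into the stored layout
    w = layer.convert_weight(layer.weight, inplace=False)
    w = patch_fn(w)
    layer.set_weight(w, seed=seed)
```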
@@ -1,6 +1,7 @@
 import torch
 import logging
 from typing import Tuple, Dict
+import comfy.float

 _LAYOUT_REGISTRY = {}
 _GENERIC_UTILS = {}
@@ -393,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
     - orig_dtype: Original dtype before quantization (for casting back)
     """
     @classmethod
-    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
+    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
         orig_dtype = tensor.dtype

         if scale is None:
@@ -403,17 +404,23 @@ class TensorCoreFP8Layout(QuantizedLayout):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)

-        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
-        # TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
-        lp_amax = torch.finfo(dtype).max
-        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
-        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
+        if inplace_ops:
+            tensor *= (1.0 / scale).to(tensor.dtype)
+        else:
+            tensor = tensor * (1.0 / scale).to(tensor.dtype)
+
+        if stochastic_rounding > 0:
+            tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
+        else:
+            lp_amax = torch.finfo(dtype).max
+            torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
+            tensor = tensor.to(dtype, memory_format=torch.contiguous_format)

         layout_params = {
             'scale': scale,
             'orig_dtype': orig_dtype
         }
-        return qdata, layout_params
+        return tensor, layout_params

     @staticmethod
     def dequantize(qdata, scale, orig_dtype, **kwargs):
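
Editor's note with a toy illustration (none of this is in the commit): nearest rounding always rounds the same direction, so repeatedly re-quantizing a weight after small patches can systematically lose the update; rounding up or down with probability proportional to the remainder keeps the result unbiased in expectation, which is presumably why quantize() grows a stochastic_rounding path used by the new set_weight.

```python
import torch

def stochastic_round_to_int(x):
    # toy integer analogue; comfy.float.stochastic_rounding does the same thing on the target float grid
    floor = torch.floor(x)
    return floor + (torch.rand_like(x) < (x - floor)).to(x.dtype)

x = torch.full((100000,), 0.1)
print(torch.round(x).mean().item())              # 0.0 -> nearest rounding drops the small values
print(stochastic_round_to_int(x).mean().item())  # ~0.1 -> unbiased on average
```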
@@ -52,6 +52,7 @@ import comfy.text_encoders.ace
 import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
 import comfy.text_encoders.hunyuan_image
+import comfy.text_encoders.z_image

 import comfy.model_patcher
 import comfy.lora
@@ -953,6 +954,8 @@ class TEModel(Enum):
     GEMMA_3_4B = 13
     MISTRAL3_24B = 14
     MISTRAL3_24B_PRUNED_FLUX2 = 15
+    QWEN3_4B = 16
+

 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -985,6 +988,8 @@ def detect_te_model(sd):
         if weight.shape[0] == 512:
             return TEModel.QWEN25_7B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
+        if 'model.layers.0.self_attn.q_norm.weight' in sd:
+            return TEModel.QWEN3_4B
         weight = sd['model.layers.0.post_attention_layernorm.weight']
         if weight.shape[0] == 5120:
             if "model.layers.39.post_attention_layernorm.weight" in sd:
@@ -1110,6 +1115,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
             clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
             tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
+        elif te_model == TEModel.QWEN3_4B:
+            clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
         else:
             # clip_l
             if clip_type == CLIPType.SD3:
@@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                  special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
                  return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
         super().__init__()
-        assert layer in self.LAYERS

         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
@@ -164,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def set_clip_options(self, options):
         layer_idx = options.get("layer", self.layer_idx)
         self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
-        if self.layer == "all":
+        if isinstance(self.layer, list) or self.layer == "all":
             pass
         elif layer_idx is None or abs(layer_idx) > self.num_layers:
             self.layer = "last"
@@ -266,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if self.enable_attention_masks:
             attention_mask_model = attention_mask

-        if self.layer == "all":
+        if isinstance(self.layer, list):
+            intermediate_output = self.layer
+        elif self.layer == "all":
             intermediate_output = "all"
         else:
             intermediate_output = self.layer_idx
@@ -21,6 +21,7 @@ import comfy.text_encoders.ace
 import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
 import comfy.text_encoders.hunyuan_image
+import comfy.text_encoders.z_image

 from . import supported_models_base
 from . import latent_formats
@@ -994,7 +995,7 @@ class Lumina2(supported_models_base.BASE):
         "shift": 6.0,
     }

-    memory_usage_factor = 1.2
+    memory_usage_factor = 1.4

     unet_extra_config = {}
     latent_format = latent_formats.Flux
@@ -1013,6 +1014,24 @@ class Lumina2(supported_models_base.BASE):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))

+class ZImage(Lumina2):
+    unet_config = {
+        "image_model": "lumina2",
+        "dim": 3840,
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 3.0,
+    }
+
+    memory_usage_factor = 1.7
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
+
 class WAN21_T2V(supported_models_base.BASE):
     unet_config = {
         "image_model": "wan2.1",
@@ -1453,7 +1472,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]


 models += [SVD_img2vid]
@@ -87,6 +87,7 @@ def load_mistral_tokenizer(data):
     vocab = {}

     max_vocab = mistral_vocab["config"]["default_vocab_size"]
+    max_vocab -= len(mistral_vocab["special_tokens"])

     for w in mistral_vocab["vocab"]:
         r = w["rank"]
@@ -137,7 +138,7 @@ class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
         return tokens

 class Mistral3_24BModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+    def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
         textmodel_json_config = {}
         num_layers = model_options.get("num_layers", None)
         if num_layers is not None:
@@ -153,7 +154,7 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
     def encode_token_weights(self, token_weight_pairs):
         out, pooled, extra = super().encode_token_weights(token_weight_pairs)

-        out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1)
+        out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
         out = out.movedim(1, 2)
         out = out.reshape(out.shape[0], out.shape[1], -1)
         return out, pooled, extra
@@ -78,6 +78,28 @@ class Qwen25_3BConfig:
     rope_scale = None
     final_norm: bool = True

+@dataclass
+class Qwen3_4BConfig:
+    vocab_size: int = 151936
+    hidden_size: int = 2560
+    intermediate_size: int = 9728
+    num_hidden_layers: int = 36
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 40960
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    rope_scale = None
+    final_norm: bool = True
+
 @dataclass
 class Qwen25_7BVLI_Config:
     vocab_size: int = 152064
@@ -434,8 +456,12 @@ class Llama2_(nn.Module):

         intermediate = None
         all_intermediate = None
+        only_layers = None
         if intermediate_output is not None:
-            if intermediate_output == "all":
+            if isinstance(intermediate_output, list):
+                all_intermediate = []
+                only_layers = set(intermediate_output)
+            elif intermediate_output == "all":
                 all_intermediate = []
                 intermediate_output = None
             elif intermediate_output < 0:
@@ -443,7 +469,8 @@ class Llama2_(nn.Module):

         for i, layer in enumerate(self.layers):
             if all_intermediate is not None:
-                all_intermediate.append(x.unsqueeze(1).clone())
+                if only_layers is None or (i in only_layers):
+                    all_intermediate.append(x.unsqueeze(1).clone())
             x = layer(
                 x=x,
                 attention_mask=mask,
@@ -457,7 +484,8 @@ class Llama2_(nn.Module):
         x = self.norm(x)

         if all_intermediate is not None:
-            all_intermediate.append(x.unsqueeze(1).clone())
+            if only_layers is None or ((i + 1) in only_layers):
+                all_intermediate.append(x.unsqueeze(1).clone())

         if all_intermediate is not None:
             intermediate = torch.cat(all_intermediate, dim=1)
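
Editor's sketch (not part of the commit): with layer=[10, 20, 30], only the hidden states entering those layers (plus the final-norm output if requested) are collected, so the stacked tensor is indexed by position in the requested list rather than by absolute layer number, which is why Flux2TEModel above switches from out[:, 10/20/30] to out[:, 0/1/2]. A toy version of the selection logic:

```python
only_layers = {10, 20, 30}
collected = []
for i in range(40):                 # pretend transformer layers
    if i in only_layers:
        collected.append(f"hidden_state_entering_layer_{i}")
print(collected)                    # three entries, now at indices 0, 1, 2
```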
@@ -505,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype

+class Qwen3_4B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Qwen3_4BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
 class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
comfy/text_encoders/z_image.py (new file)
@@ -0,0 +1,48 @@
+from transformers import Qwen2Tokenizer
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import os
+
+class Qwen3Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class ZImageTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer)
+        self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+        if llama_template is None:
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+        return tokens
+
+
+class Qwen3_4BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class ZImageTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
+
+
+def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+    class ZImageTEModel_(ZImageTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            if llama_quantization_metadata is not None:
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return ZImageTEModel_
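
Editor's sketch (not part of the commit): how the new Z Image text-encoder pieces are assumed to fit together. The prompt text is a placeholder, and the loader normally passes detected dtype/quantization options into te().

```python
import comfy.text_encoders.z_image as z_image

# The tokenizer wraps the prompt in the Qwen chat template defined in the file above.
tokenizer = z_image.ZImageTokenizer()
tokens = tokenizer.tokenize_with_weights("a photo of a cat")

# te() is a factory returning a configured SD1ClipModel subclass.
ZImageTE = z_image.te()
# te_model = ZImageTE(device="cpu")  # would build the Qwen3-4B encoder; weights are loaded separately
```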
|
||||||
@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase):
|
|||||||
lora_diff = torch.mm(
|
lora_diff = torch.mm(
|
||||||
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
|
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
|
||||||
).reshape(weight.shape)
|
).reshape(weight.shape)
|
||||||
|
del mat1, mat2
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = weight_decompose(
|
weight = weight_decompose(
|
||||||
dora_scale,
|
dora_scale,
|
||||||
|
|||||||
@ -58,8 +58,14 @@ class GeminiInlineData(BaseModel):
|
|||||||
mimeType: GeminiMimeType | None = Field(None)
|
mimeType: GeminiMimeType | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiFileData(BaseModel):
|
||||||
|
fileUri: str | None = Field(None)
|
||||||
|
mimeType: GeminiMimeType | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class GeminiPart(BaseModel):
|
class GeminiPart(BaseModel):
|
||||||
inlineData: GeminiInlineData | None = Field(None)
|
inlineData: GeminiInlineData | None = Field(None)
|
||||||
|
fileData: GeminiFileData | None = Field(None)
|
||||||
text: str | None = Field(None)
|
text: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,34 +1,21 @@
-from typing import Optional, Union
-from enum import Enum
+from typing import Optional
 from pydantic import BaseModel, Field
 
 
-class Image2(BaseModel):
-    bytesBase64Encoded: str
-    gcsUri: Optional[str] = None
-    mimeType: Optional[str] = None
-
-
-class Image3(BaseModel):
-    bytesBase64Encoded: Optional[str] = None
-    gcsUri: str
-    mimeType: Optional[str] = None
-
-
-class Instance1(BaseModel):
-    image: Optional[Union[Image2, Image3]] = Field(
-        None, description='Optional image to guide video generation'
-    )
+class VeoRequestInstanceImage(BaseModel):
+    bytesBase64Encoded: str | None = Field(None)
+    gcsUri: str | None = Field(None)
+    mimeType: str | None = Field(None)
+
+
+class VeoRequestInstance(BaseModel):
+    image: VeoRequestInstanceImage | None = Field(None)
+    lastFrame: VeoRequestInstanceImage | None = Field(None)
     prompt: str = Field(..., description='Text description of the video')
 
 
-class PersonGeneration1(str, Enum):
-    ALLOW = 'ALLOW'
-    BLOCK = 'BLOCK'
-
-
-class Parameters1(BaseModel):
+class VeoRequestParameters(BaseModel):
     aspectRatio: Optional[str] = Field(None, examples=['16:9'])
     durationSeconds: Optional[int] = None
     enhancePrompt: Optional[bool] = None
@@ -37,17 +24,18 @@ class Parameters1(BaseModel):
         description='Generate audio for the video. Only supported by veo 3 models.',
     )
     negativePrompt: Optional[str] = None
-    personGeneration: Optional[PersonGeneration1] = None
+    personGeneration: str | None = Field(None, description="ALLOW or BLOCK")
     sampleCount: Optional[int] = None
    seed: Optional[int] = None
     storageUri: Optional[str] = Field(
         None, description='Optional Cloud Storage URI to upload the video'
     )
+    resolution: str | None = Field(None)
 
 
 class VeoGenVidRequest(BaseModel):
-    instances: Optional[list[Instance1]] = None
-    parameters: Optional[Parameters1] = None
+    instances: list[VeoRequestInstance] | None = Field(None)
+    parameters: VeoRequestParameters | None = Field(None)
 
 
 class VeoGenVidResponse(BaseModel):
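A minimal sketch of how the renamed request models compose, mirroring how the Veo nodes later in this patch build their payload; the base64 string is a placeholder and model_dump(exclude_none=True) assumes pydantic v2:

request = VeoGenVidRequest(
    instances=[
        VeoRequestInstance(
            prompt="a timelapse of clouds over a mountain lake",
            image=VeoRequestInstanceImage(
                bytesBase64Encoded="<base64 PNG>",  # placeholder payload
                mimeType="image/png",
            ),
        )
    ],
    parameters=VeoRequestParameters(
        aspectRatio="16:9",
        durationSeconds=8,
        generateAudio=True,
        resolution="720p",
    ),
)
payload = request.model_dump(exclude_none=True)  # unset optional fields are dropped from the JSON body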
@@ -4,10 +4,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
 """
 
 import base64
-import json
 import os
-import time
-import uuid
 from enum import Enum
 from io import BytesIO
 from typing import Literal
@@ -20,6 +17,7 @@ from comfy_api.latest import IO, ComfyExtension, Input
 from comfy_api.util import VideoCodec, VideoContainer
 from comfy_api_nodes.apis.gemini_api import (
     GeminiContent,
+    GeminiFileData,
     GeminiGenerateContentRequest,
     GeminiGenerateContentResponse,
     GeminiImageConfig,
@@ -38,10 +36,10 @@ from comfy_api_nodes.util import (
     get_number_of_images,
     sync_op,
     tensor_to_base64_string,
+    upload_images_to_comfyapi,
     validate_string,
     video_to_base64_string,
 )
-from server import PromptServer
 
 GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
 GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024  # 20 MB
@@ -68,24 +66,43 @@ class GeminiImageModel(str, Enum):
     gemini_2_5_flash_image = "gemini-2.5-flash-image"
 
 
-def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
-    """
-    Convert image tensor input to Gemini API compatible parts.
-
-    Args:
-        image_input: Batch of image tensors from ComfyUI.
-
-    Returns:
-        List of GeminiPart objects containing the encoded images.
-    """
+async def create_image_parts(
+    cls: type[IO.ComfyNode],
+    images: torch.Tensor,
+    image_limit: int = 0,
+) -> list[GeminiPart]:
     image_parts: list[GeminiPart] = []
-    for image_index in range(image_input.shape[0]):
-        image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
+    if image_limit < 0:
+        raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
+    total_images = get_number_of_images(images)
+    if total_images <= 0:
+        raise ValueError("No images provided to create_image_parts; at least one image is required.")
+
+    # If image_limit == 0 --> use all images; otherwise clamp to image_limit.
+    effective_max = total_images if image_limit == 0 else min(total_images, image_limit)
+
+    # Number of images we'll send as URLs (fileData)
+    num_url_images = min(effective_max, 10)  # Vertex API max number of image links
+    reference_images_urls = await upload_images_to_comfyapi(
+        cls,
+        images,
+        max_images=num_url_images,
+    )
+    for reference_image_url in reference_images_urls:
+        image_parts.append(
+            GeminiPart(
+                fileData=GeminiFileData(
+                    mimeType=GeminiMimeType.image_png,
+                    fileUri=reference_image_url,
+                )
+            )
+        )
+    for idx in range(num_url_images, effective_max):
         image_parts.append(
             GeminiPart(
                 inlineData=GeminiInlineData(
                     mimeType=GeminiMimeType.image_png,
-                    data=image_as_b64,
+                    data=tensor_to_base64_string(images[idx]),
                 )
             )
         )
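A worked example of the split implemented above, assuming a node class cls and a batch of 14 ComfyUI images with the default image_limit of 0:

# Hypothetical call from an async node method; images has shape [14, H, W, C].
parts = await create_image_parts(cls, images)
# len(parts) == 14:
#   parts[0:10]  -> GeminiPart(fileData=...)   the first 10 images are uploaded and referenced by fileUri
#   parts[10:14] -> GeminiPart(inlineData=...) the remainder are embedded as base64 PNGs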
@@ -338,8 +355,7 @@ class GeminiNode(IO.ComfyNode):
 
         # Add other modal parts
         if images is not None:
-            image_parts = create_image_parts(images)
-            parts.extend(image_parts)
+            parts.extend(await create_image_parts(cls, images))
         if audio is not None:
             parts.extend(cls.create_audio_parts(audio))
         if video is not None:
@@ -364,29 +380,6 @@ class GeminiNode(IO.ComfyNode):
         )
 
         output_text = get_text_from_response(response)
-        if output_text:
-            # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
-            render_spec = {
-                "node_id": cls.hidden.unique_id,
-                "component": "ChatHistoryWidget",
-                "props": {
-                    "history": json.dumps(
-                        [
-                            {
-                                "prompt": prompt,
-                                "response": output_text,
-                                "response_id": str(uuid.uuid4()),
-                                "timestamp": time.time(),
-                            }
-                        ]
-                    ),
-                },
-            }
-            PromptServer.instance.send_sync(
-                "display_component",
-                render_spec,
-            )
-
         return IO.NodeOutput(output_text or "Empty response from Gemini model...")
 
 
@@ -562,8 +555,7 @@ class GeminiImage(IO.ComfyNode):
         image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
 
         if images is not None:
-            image_parts = create_image_parts(images)
-            parts.extend(image_parts)
+            parts.extend(await create_image_parts(cls, images))
         if files is not None:
             parts.extend(files)
 
@@ -582,30 +574,7 @@ class GeminiImage(IO.ComfyNode):
             response_model=GeminiGenerateContentResponse,
             price_extractor=calculate_tokens_price,
         )
-
-        output_text = get_text_from_response(response)
-        if output_text:
-            render_spec = {
-                "node_id": cls.hidden.unique_id,
-                "component": "ChatHistoryWidget",
-                "props": {
-                    "history": json.dumps(
-                        [
-                            {
-                                "prompt": prompt,
-                                "response": output_text,
-                                "response_id": str(uuid.uuid4()),
-                                "timestamp": time.time(),
-                            }
-                        ]
-                    ),
-                },
-            }
-            PromptServer.instance.send_sync(
-                "display_component",
-                render_spec,
-            )
-        return IO.NodeOutput(get_image_from_response(response), output_text)
+        return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
 
 
 class GeminiImage2(IO.ComfyNode):
@@ -702,7 +671,7 @@ class GeminiImage2(IO.ComfyNode):
         if images is not None:
             if get_number_of_images(images) > 14:
                 raise ValueError("The current maximum number of supported images is 14.")
-            parts.extend(create_image_parts(images))
+            parts.extend(await create_image_parts(cls, images))
         if files is not None:
             parts.extend(files)
 
@@ -725,30 +694,7 @@ class GeminiImage2(IO.ComfyNode):
             response_model=GeminiGenerateContentResponse,
             price_extractor=calculate_tokens_price,
         )
-
-        output_text = get_text_from_response(response)
-        if output_text:
-            render_spec = {
-                "node_id": cls.hidden.unique_id,
-                "component": "ChatHistoryWidget",
-                "props": {
-                    "history": json.dumps(
-                        [
-                            {
-                                "prompt": prompt,
-                                "response": output_text,
-                                "response_id": str(uuid.uuid4()),
-                                "timestamp": time.time(),
-                            }
-                        ]
-                    ),
-                },
-            }
-            PromptServer.instance.send_sync(
-                "display_component",
-                render_spec,
-            )
-        return IO.NodeOutput(get_image_from_response(response), output_text)
+        return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
 
 
 class GeminiExtension(ComfyExtension):
@@ -1,15 +1,10 @@
 from io import BytesIO
-from typing import Optional, Union
-import json
 import os
-import time
-import uuid
 from enum import Enum
 from inspect import cleandoc
 import numpy as np
 import torch
 from PIL import Image
-from server import PromptServer
 import folder_paths
 import base64
 from comfy_api.latest import IO, ComfyExtension
@@ -587,11 +582,11 @@ class OpenAIChatNode(IO.ComfyNode):
     def create_input_message_contents(
         cls,
         prompt: str,
-        image: Optional[torch.Tensor] = None,
-        files: Optional[list[InputFileContent]] = None,
+        image: torch.Tensor | None = None,
+        files: list[InputFileContent] | None = None,
     ) -> InputMessageContentList:
         """Create a list of input message contents from prompt and optional image."""
-        content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [
+        content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [
             InputTextContent(text=prompt, type="input_text"),
         ]
         if image is not None:
@@ -617,9 +612,9 @@ class OpenAIChatNode(IO.ComfyNode):
         prompt: str,
         persist_context: bool = False,
         model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value,
-        images: Optional[torch.Tensor] = None,
-        files: Optional[list[InputFileContent]] = None,
-        advanced_options: Optional[CreateModelResponseProperties] = None,
+        images: torch.Tensor | None = None,
+        files: list[InputFileContent] | None = None,
+        advanced_options: CreateModelResponseProperties | None = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False)
 
@@ -660,30 +655,7 @@ class OpenAIChatNode(IO.ComfyNode):
             status_extractor=lambda response: response.status,
             completed_statuses=["incomplete", "completed"]
         )
-        output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))
-
-        # Update history
-        render_spec = {
-            "node_id": cls.hidden.unique_id,
-            "component": "ChatHistoryWidget",
-            "props": {
-                "history": json.dumps(
-                    [
-                        {
-                            "prompt": prompt,
-                            "response": output_text,
-                            "response_id": str(uuid.uuid4()),
-                            "timestamp": time.time(),
-                        }
-                    ]
-                ),
-            },
-        }
-        PromptServer.instance.send_sync(
-            "display_component",
-            render_spec,
-        )
-        return IO.NodeOutput(output_text)
+        return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)))
 
 
 class OpenAIInputFiles(IO.ComfyNode):
@@ -790,8 +762,8 @@ class OpenAIChatConfig(IO.ComfyNode):
     def execute(
         cls,
         truncation: bool,
-        instructions: Optional[str] = None,
-        max_output_tokens: Optional[int] = None,
+        instructions: str | None = None,
+        max_output_tokens: int | None = None,
     ) -> IO.NodeOutput:
         """
         Configure advanced options for the OpenAI Chat Node.
|||||||
@ -1,6 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.input_impl.video_types import VideoFromFile
|
from comfy_api.input_impl.video_types import VideoFromFile
|
||||||
@ -10,6 +11,9 @@ from comfy_api_nodes.apis.veo_api import (
|
|||||||
VeoGenVidPollResponse,
|
VeoGenVidPollResponse,
|
||||||
VeoGenVidRequest,
|
VeoGenVidRequest,
|
||||||
VeoGenVidResponse,
|
VeoGenVidResponse,
|
||||||
|
VeoRequestInstance,
|
||||||
|
VeoRequestInstanceImage,
|
||||||
|
VeoRequestParameters,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@ -346,12 +350,163 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="Veo3FirstLastFrameNode",
|
||||||
|
display_name="Google Veo 3 First-Last-Frame to Video",
|
||||||
|
category="api node/video/Veo",
|
||||||
|
description="Generate video using prompt and first and last frames.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text description of the video",
|
||||||
|
),
|
||||||
|
IO.String.Input(
|
||||||
|
"negative_prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||||
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=["16:9", "9:16"],
|
||||||
|
default="16:9",
|
||||||
|
tooltip="Aspect ratio of the output video",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=8,
|
||||||
|
min=4,
|
||||||
|
max=8,
|
||||||
|
step=2,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Duration of the output video in seconds",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFF,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed for video generation",
|
||||||
|
),
|
||||||
|
IO.Image.Input("first_frame", tooltip="Start frame"),
|
||||||
|
IO.Image.Input("last_frame", tooltip="End frame"),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"model",
|
||||||
|
options=["veo-3.1-generate", "veo-3.1-fast-generate"],
|
||||||
|
default="veo-3.1-fast-generate",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"generate_audio",
|
||||||
|
default=True,
|
||||||
|
tooltip="Generate audio for the video.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
resolution: str,
|
||||||
|
aspect_ratio: str,
|
||||||
|
duration: int,
|
||||||
|
seed: int,
|
||||||
|
first_frame: torch.Tensor,
|
||||||
|
last_frame: torch.Tensor,
|
||||||
|
model: str,
|
||||||
|
generate_audio: bool,
|
||||||
|
):
|
||||||
|
model = MODELS_MAP[model]
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
|
||||||
|
response_model=VeoGenVidResponse,
|
||||||
|
data=VeoGenVidRequest(
|
||||||
|
instances=[
|
||||||
|
VeoRequestInstance(
|
||||||
|
prompt=prompt,
|
||||||
|
image=VeoRequestInstanceImage(
|
||||||
|
bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png"
|
||||||
|
),
|
||||||
|
lastFrame=VeoRequestInstanceImage(
|
||||||
|
bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
parameters=VeoRequestParameters(
|
||||||
|
aspectRatio=aspect_ratio,
|
||||||
|
personGeneration="ALLOW",
|
||||||
|
durationSeconds=duration,
|
||||||
|
enhancePrompt=True, # cannot be False for Veo3
|
||||||
|
seed=seed,
|
||||||
|
generateAudio=generate_audio,
|
||||||
|
negativePrompt=negative_prompt,
|
||||||
|
resolution=resolution,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
poll_response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
|
||||||
|
response_model=VeoGenVidPollResponse,
|
||||||
|
status_extractor=lambda r: "completed" if r.done else "pending",
|
||||||
|
data=VeoGenVidPollRequest(
|
||||||
|
operationName=initial_response.name,
|
||||||
|
),
|
||||||
|
poll_interval=5.0,
|
||||||
|
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
if poll_response.error:
|
||||||
|
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
|
||||||
|
|
||||||
|
response = poll_response.response
|
||||||
|
filtered_count = response.raiMediaFilteredCount
|
||||||
|
if filtered_count:
|
||||||
|
reasons = response.raiMediaFilteredReasons or []
|
||||||
|
reason_part = f": {reasons[0]}" if reasons else ""
|
||||||
|
raise Exception(
|
||||||
|
f"Content blocked by Google's Responsible AI filters{reason_part} "
|
||||||
|
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.videos:
|
||||||
|
video = response.videos[0]
|
||||||
|
if video.bytesBase64Encoded:
|
||||||
|
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||||
|
if video.gcsUri:
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||||
|
raise Exception("Video returned but no data or URL was provided")
|
||||||
|
raise Exception("Video generation completed but no video was returned")
|
||||||
|
|
||||||
|
|
||||||
class VeoExtension(ComfyExtension):
|
class VeoExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
VeoVideoGenerationNode,
|
VeoVideoGenerationNode,
|
||||||
Veo3VideoGenerationNode,
|
Veo3VideoGenerationNode,
|
||||||
|
Veo3FirstLastFrameNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import logging
 import time
 import uuid
 from io import BytesIO
-from typing import Optional, Union
+from typing import Optional
 from urllib.parse import urlparse
 
 import aiohttp
@@ -48,8 +48,9 @@ async def upload_images_to_comfyapi(
     image: torch.Tensor,
     *,
     max_images: int = 8,
-    mime_type: Optional[str] = None,
-    wait_label: Optional[str] = "Uploading",
+    mime_type: str | None = None,
+    wait_label: str | None = "Uploading",
+    show_batch_index: bool = True,
 ) -> list[str]:
     """
     Uploads images to ComfyUI API and returns download URLs.
@@ -59,11 +60,18 @@ async def upload_images_to_comfyapi(
     download_urls: list[str] = []
     is_batch = len(image.shape) > 3
     batch_len = image.shape[0] if is_batch else 1
+    num_to_upload = min(batch_len, max_images)
+    batch_start_ts = time.monotonic()
 
-    for idx in range(min(batch_len, max_images)):
+    for idx in range(num_to_upload):
         tensor = image[idx] if is_batch else image
         img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
-        url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
+
+        effective_label = wait_label
+        if wait_label and show_batch_index and num_to_upload > 1:
+            effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
+
+        url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
         download_urls.append(url)
     return download_urls
 
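A small example of the batch-aware labels produced by the logic above (wait_label="Uploading", show_batch_index=True, three images); this snippet only mirrors the f-string used in the loop:

wait_label = "Uploading"
num_to_upload = 3
labels = [f"{wait_label} ({idx + 1}/{num_to_upload})" for idx in range(num_to_upload)]
print(labels)  # ['Uploading (1/3)', 'Uploading (2/3)', 'Uploading (3/3)']

A single image, or show_batch_index=False, keeps the plain label, and batch_start_ts lets every upload in the batch report progress measured from the start of the whole batch.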
@@ -126,8 +134,9 @@ async def upload_file_to_comfyapi(
     cls: type[IO.ComfyNode],
     file_bytes_io: BytesIO,
     filename: str,
-    upload_mime_type: Optional[str],
-    wait_label: Optional[str] = "Uploading",
+    upload_mime_type: str | None,
+    wait_label: str | None = "Uploading",
+    progress_origin_ts: float | None = None,
 ) -> str:
     """Uploads a single file to ComfyUI API and returns its download URL."""
     if upload_mime_type is None:
@@ -148,6 +157,7 @@ async def upload_file_to_comfyapi(
         file_bytes_io,
         content_type=upload_mime_type,
         wait_label=wait_label,
+        progress_origin_ts=progress_origin_ts,
     )
     return create_resp.download_url
 
@@ -155,27 +165,18 @@ async def upload_file_to_comfyapi(
 async def upload_file(
     cls: type[IO.ComfyNode],
     upload_url: str,
-    file: Union[BytesIO, str],
+    file: BytesIO | str,
     *,
-    content_type: Optional[str] = None,
+    content_type: str | None = None,
     max_retries: int = 3,
     retry_delay: float = 1.0,
     retry_backoff: float = 2.0,
-    wait_label: Optional[str] = None,
+    wait_label: str | None = None,
+    progress_origin_ts: float | None = None,
 ) -> None:
     """
     Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
 
-    Args:
-        cls: Node class (provides auth context + UI progress hooks).
-        upload_url: Pre-signed PUT URL.
-        file: BytesIO or path string.
-        content_type: Explicit MIME type. If None, we *suppress* Content-Type.
-        max_retries: Maximum retry attempts.
-        retry_delay: Initial delay in seconds.
-        retry_backoff: Exponential backoff factor.
-        wait_label: Progress label shown in Comfy UI.
-
     Raises:
         ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
     """
@@ -198,7 +199,7 @@ async def upload_file(
 
     attempt = 0
     delay = retry_delay
-    start_ts = time.monotonic()
+    start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
     op_uuid = uuid.uuid4().hex[:8]
     while True:
         attempt += 1
comfy_extras/nodes_dataset.py (new file, 1432 lines added): file diff suppressed because it is too large
@@ -7,6 +7,10 @@ from comfy_api.input_impl import VideoFromFile
 
 from pathlib import Path
 
+from PIL import Image
+import numpy as np
+
+import uuid
 
 def normalize_path(path):
     return path.replace('\\', '/')
@@ -34,58 +38,6 @@ class Load3D():
             "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
         }}
 
-    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
-    RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
-
-    FUNCTION = "process"
-    EXPERIMENTAL = True
-
-    CATEGORY = "3d"
-
-    def process(self, model_file, image, **kwargs):
-        image_path = folder_paths.get_annotated_filepath(image['image'])
-        mask_path = folder_paths.get_annotated_filepath(image['mask'])
-        normal_path = folder_paths.get_annotated_filepath(image['normal'])
-        lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
-
-        load_image_node = nodes.LoadImage()
-        output_image, ignore_mask = load_image_node.load_image(image=image_path)
-        ignore_image, output_mask = load_image_node.load_image(image=mask_path)
-        normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
-        lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
-
-        video = None
-
-        if image['recording'] != "":
-            recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
-
-            video = VideoFromFile(recording_video_path)
-
-        return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
-
-class Load3DAnimation():
-    @classmethod
-    def INPUT_TYPES(s):
-        input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
-
-        os.makedirs(input_dir, exist_ok=True)
-
-        input_path = Path(input_dir)
-        base_path = Path(folder_paths.get_input_directory())
-
-        files = [
-            normalize_path(str(file_path.relative_to(base_path)))
-            for file_path in input_path.rglob("*")
-            if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
-        ]
-
-        return {"required": {
-            "model_file": (sorted(files), {"file_upload": True}),
-            "image": ("LOAD_3D_ANIMATION", {}),
-            "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-            "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
-        }}
-
     RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
     RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
 
@@ -120,7 +72,8 @@ class Preview3D():
             "model_file": ("STRING", {"default": "", "multiline": False}),
         },
         "optional": {
-            "camera_info": ("LOAD3D_CAMERA", {})
+            "camera_info": ("LOAD3D_CAMERA", {}),
+            "bg_image": ("IMAGE", {})
         }}
 
     OUTPUT_NODE = True
@@ -133,50 +86,33 @@ class Preview3D():
 
     def process(self, model_file, **kwargs):
         camera_info = kwargs.get("camera_info", None)
+        bg_image = kwargs.get("bg_image", None)
+
+        bg_image_path = None
+        if bg_image is not None:
+
+            img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
+            img = Image.fromarray(img_array)
+
+            temp_dir = folder_paths.get_temp_directory()
+            filename = f"bg_{uuid.uuid4().hex}.png"
+            bg_image_path = os.path.join(temp_dir, filename)
+            img.save(bg_image_path, compress_level=1)
+
+            bg_image_path = f"temp/{filename}"
+
         return {
             "ui": {
-                "result": [model_file, camera_info]
-            }
-        }
-
-class Preview3DAnimation():
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "model_file": ("STRING", {"default": "", "multiline": False}),
-        },
-        "optional": {
-            "camera_info": ("LOAD3D_CAMERA", {})
-        }}
-
-    OUTPUT_NODE = True
-    RETURN_TYPES = ()
-
-    CATEGORY = "3d"
-
-    FUNCTION = "process"
-    EXPERIMENTAL = True
-
-    def process(self, model_file, **kwargs):
-        camera_info = kwargs.get("camera_info", None)
-
-        return {
-            "ui": {
-                "result": [model_file, camera_info]
+                "result": [model_file, camera_info, bg_image_path]
             }
         }
 
 NODE_CLASS_MAPPINGS = {
     "Load3D": Load3D,
-    "Load3DAnimation": Load3DAnimation,
     "Preview3D": Preview3D,
-    "Preview3DAnimation": Preview3DAnimation
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
-    "Load3D": "Load 3D",
-    "Load3DAnimation": "Load 3D - Animation",
-    "Preview3D": "Preview 3D",
-    "Preview3DAnimation": "Preview 3D - Animation"
+    "Load3D": "Load 3D & Animation",
+    "Preview3D": "Preview 3D & Animation",
 }
File diff suppressed because it is too large
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.71"
+__version__ = "0.3.75"
@@ -37,13 +37,16 @@ class TAESDPreviewerImpl(LatentPreviewer):
 
 
 class Latent2RGBPreviewer(LatentPreviewer):
-    def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
+    def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None):
         self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
         self.latent_rgb_factors_bias = None
         if latent_rgb_factors_bias is not None:
             self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
+        self.latent_rgb_factors_reshape = latent_rgb_factors_reshape
 
     def decode_latent_to_preview(self, x0):
+        if self.latent_rgb_factors_reshape is not None:
+            x0 = self.latent_rgb_factors_reshape(x0)
         self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
         if self.latent_rgb_factors_bias is not None:
             self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
@@ -85,7 +88,7 @@ def get_previewer(device, latent_format):
 
     if previewer is None:
         if latent_format.latent_rgb_factors is not None:
-            previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
+            previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape)
     return previewer
 
 def prepare_callback(model, steps, x0_output_dict=None):
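A minimal sketch of the new reshape hook, assuming a latent format whose packed channels need to be unpacked before the RGB projection; the 2x2 unpacking below is illustrative, not taken from any particular format:

import torch

def unpack_for_preview(t):
    # Illustrative hook: a [B, 128, H, W] packed latent becomes [B, 32, 2H, 2W].
    b, c, h, w = t.shape
    return t.reshape(b, c // 4, 2, 2, h, w).permute(0, 1, 4, 2, 5, 3).reshape(b, c // 4, h * 2, w * 2)

x0 = torch.randn(1, 128, 32, 32)
print(unpack_for_preview(x0).shape)  # torch.Size([1, 32, 64, 64])

Latent2RGBPreviewer applies such a callable to x0 before projecting the unpacked channels to RGB with the factor matrix.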
nodes.py
@@ -2278,6 +2278,7 @@ async def init_builtin_extra_nodes():
         "nodes_images.py",
         "nodes_video_model.py",
         "nodes_train.py",
+        "nodes_dataset.py",
         "nodes_sag.py",
         "nodes_perpneg.py",
         "nodes_stable3d.py",
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.71"
+version = "0.3.75"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.30.6
-comfyui-workflow-templates==0.7.9
+comfyui-frontend-package==1.32.9
+comfyui-workflow-templates==0.7.20
 comfyui-embedded-docs==0.3.1
 torch
 torchsde
@@ -174,7 +174,7 @@ def create_block_external_middleware():
         else:
             response = await handler(request)
 
-        response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
+        response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
         return response
 
     return block_external_middleware