mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
Merge remote-tracking branch 'origin/master' into research-comfy
This commit is contained in:
commit
6fe024ff85
@ -139,9 +139,9 @@ Example:
|
|||||||
"_quantization_metadata": {
|
"_quantization_metadata": {
|
||||||
"format_version": "1.0",
|
"format_version": "1.0",
|
||||||
"layers": {
|
"layers": {
|
||||||
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
|
"model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
|
||||||
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
|
"model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
|
||||||
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
|
"model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
303
comfy/ldm/ernie/model.py
Normal file
303
comfy/ldm/ernie/model.py
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||||
|
assert dim % 2 == 0
|
||||||
|
if not comfy.model_management.supports_fp64(pos.device):
|
||||||
|
device = torch.device("cpu")
|
||||||
|
else:
|
||||||
|
device = pos.device
|
||||||
|
|
||||||
|
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
|
||||||
|
omega = 1.0 / (theta**scale)
|
||||||
|
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||||
|
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
|
||||||
|
return out.to(dtype=torch.float32, device=pos.device)
|
||||||
|
|
||||||
|
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||||
|
rot_dim = freqs_cis.shape[-1]
|
||||||
|
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
|
||||||
|
cos_ = freqs_cis[0]
|
||||||
|
sin_ = freqs_cis[1]
|
||||||
|
x1, x2 = x.chunk(2, dim=-1)
|
||||||
|
x_rotated = torch.cat((-x2, x1), dim=-1)
|
||||||
|
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
|
||||||
|
|
||||||
|
class ErnieImageEmbedND3(nn.Module):
|
||||||
|
def __init__(self, dim: int, theta: int, axes_dim: tuple):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = list(axes_dim)
|
||||||
|
|
||||||
|
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
|
||||||
|
emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2]
|
||||||
|
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
|
||||||
|
|
||||||
|
class ErnieImagePatchEmbedDynamic(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.proj(x)
|
||||||
|
batch_size, dim, height, width = x.shape
|
||||||
|
return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
|
||||||
|
def __init__(self, num_channels: int, flip_sin_to_cos: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.flip_sin_to_cos = flip_sin_to_cos
|
||||||
|
|
||||||
|
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||||
|
half_dim = self.num_channels // 2
|
||||||
|
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) / half_dim
|
||||||
|
emb = torch.exp(exponent)
|
||||||
|
emb = timesteps[:, None].float() * emb[None, :]
|
||||||
|
if self.flip_sin_to_cos:
|
||||||
|
emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
|
||||||
|
else:
|
||||||
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, time_embed_dim: int, operations, device=None, dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
Linear = operations.Linear
|
||||||
|
self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, device=device, dtype=dtype)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, sample: torch.Tensor) -> torch.Tensor:
|
||||||
|
sample = self.linear_1(sample)
|
||||||
|
sample = self.act(sample)
|
||||||
|
sample = self.linear_2(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
class ErnieImageAttention(nn.Module):
|
||||||
|
def __init__(self, query_dim: int, heads: int, dim_head: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = heads
|
||||||
|
self.head_dim = dim_head
|
||||||
|
self.inner_dim = heads * dim_head
|
||||||
|
|
||||||
|
Linear = operations.Linear
|
||||||
|
RMSNorm = operations.RMSNorm
|
||||||
|
|
||||||
|
self.to_q = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
|
||||||
|
self.to_k = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
|
||||||
|
self.to_v = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
|
||||||
|
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.to_out = nn.ModuleList([Linear(self.inner_dim, query_dim, bias=False, device=device, dtype=dtype)])
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_rotary_emb: torch.Tensor = None) -> torch.Tensor:
|
||||||
|
B, S, _ = x.shape
|
||||||
|
|
||||||
|
q_flat = self.to_q(x)
|
||||||
|
k_flat = self.to_k(x)
|
||||||
|
v_flat = self.to_v(x)
|
||||||
|
|
||||||
|
query = q_flat.view(B, S, self.heads, self.head_dim)
|
||||||
|
key = k_flat.view(B, S, self.heads, self.head_dim)
|
||||||
|
|
||||||
|
query = self.norm_q(query)
|
||||||
|
key = self.norm_k(key)
|
||||||
|
|
||||||
|
if image_rotary_emb is not None:
|
||||||
|
query = apply_rotary_emb(query, image_rotary_emb)
|
||||||
|
key = apply_rotary_emb(key, image_rotary_emb)
|
||||||
|
|
||||||
|
query, key = query.to(x.dtype), key.to(x.dtype)
|
||||||
|
|
||||||
|
q_flat = query.reshape(B, S, -1)
|
||||||
|
k_flat = key.reshape(B, S, -1)
|
||||||
|
|
||||||
|
hidden_states = optimized_attention(q_flat, k_flat, v_flat, self.heads, mask=attention_mask)
|
||||||
|
|
||||||
|
return self.to_out[0](hidden_states)
|
||||||
|
|
||||||
|
class ErnieImageFeedForward(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, ffn_hidden_size: int, operations, device=None, dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
Linear = operations.Linear
|
||||||
|
self.gate_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
|
self.up_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
|
self.linear_fc2 = Linear(ffn_hidden_size, hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
|
||||||
|
|
||||||
|
class ErnieImageSharedAdaLNBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
RMSNorm = operations.RMSNorm
|
||||||
|
|
||||||
|
self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype)
|
||||||
|
self.self_attention = ErnieImageAttention(
|
||||||
|
query_dim=hidden_size,
|
||||||
|
dim_head=hidden_size // num_heads,
|
||||||
|
heads=num_heads,
|
||||||
|
eps=eps,
|
||||||
|
operations=operations,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype
|
||||||
|
)
|
||||||
|
self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype)
|
||||||
|
self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size, operations=operations, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, rotary_pos_emb, temb, attention_mask=None):
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
x_norm = self.adaLN_sa_ln(x)
|
||||||
|
x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
|
||||||
|
|
||||||
|
attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
|
||||||
|
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
x_norm = self.adaLN_mlp_ln(x)
|
||||||
|
x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
|
||||||
|
|
||||||
|
return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype)
|
||||||
|
|
||||||
|
class ErnieImageAdaLNContinuous(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
LayerNorm = operations.LayerNorm
|
||||||
|
Linear = operations.Linear
|
||||||
|
self.norm = LayerNorm(hidden_size, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
|
||||||
|
self.linear = Linear(hidden_size, hidden_size * 2, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
||||||
|
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class ErnieImageModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int = 4096,
|
||||||
|
num_attention_heads: int = 32,
|
||||||
|
num_layers: int = 36,
|
||||||
|
ffn_hidden_size: int = 12288,
|
||||||
|
in_channels: int = 128,
|
||||||
|
out_channels: int = 128,
|
||||||
|
patch_size: int = 1,
|
||||||
|
text_in_dim: int = 3072,
|
||||||
|
rope_theta: int = 256,
|
||||||
|
rope_axes_dim: tuple = (32, 48, 48),
|
||||||
|
eps: float = 1e-6,
|
||||||
|
qk_layernorm: bool = True,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_heads = num_attention_heads
|
||||||
|
self.head_dim = hidden_size // num_attention_heads
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
Linear = operations.Linear
|
||||||
|
|
||||||
|
self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size, operations, device, dtype)
|
||||||
|
self.text_proj = Linear(text_in_dim, hidden_size, bias=False, device=device, dtype=dtype) if text_in_dim != hidden_size else None
|
||||||
|
|
||||||
|
self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False)
|
||||||
|
self.time_embedding = TimestepEmbedding(hidden_size, hidden_size, operations, device, dtype)
|
||||||
|
|
||||||
|
self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
|
||||||
|
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
Linear(hidden_size, 6 * hidden_size, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, operations, device, dtype)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps, operations, device, dtype)
|
||||||
|
self.final_linear = Linear(hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, context, **kwargs):
|
||||||
|
device, dtype = x.device, x.dtype
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
|
||||||
|
N_img = Hp * Wp
|
||||||
|
|
||||||
|
img_bsh = self.x_embedder(x)
|
||||||
|
|
||||||
|
text_bth = context
|
||||||
|
if self.text_proj is not None and text_bth.numel() > 0:
|
||||||
|
text_bth = self.text_proj(text_bth)
|
||||||
|
Tmax = text_bth.shape[1]
|
||||||
|
|
||||||
|
hidden_states = torch.cat([img_bsh, text_bth], dim=1)
|
||||||
|
|
||||||
|
text_ids = torch.zeros((B, Tmax, 3), device=device, dtype=torch.float32)
|
||||||
|
text_ids[:, :, 0] = torch.linspace(0, Tmax - 1, steps=Tmax, device=x.device, dtype=torch.float32)
|
||||||
|
index = float(Tmax)
|
||||||
|
|
||||||
|
transformer_options = kwargs.get("transformer_options", {})
|
||||||
|
rope_options = transformer_options.get("rope_options", None)
|
||||||
|
|
||||||
|
h_len, w_len = float(Hp), float(Wp)
|
||||||
|
h_offset, w_offset = 0.0, 0.0
|
||||||
|
|
||||||
|
if rope_options is not None:
|
||||||
|
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||||
|
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||||
|
index += rope_options.get("shift_t", 0.0)
|
||||||
|
h_offset += rope_options.get("shift_y", 0.0)
|
||||||
|
w_offset += rope_options.get("shift_x", 0.0)
|
||||||
|
|
||||||
|
image_ids = torch.zeros((Hp, Wp, 3), device=device, dtype=torch.float32)
|
||||||
|
image_ids[:, :, 0] = image_ids[:, :, 1] + index
|
||||||
|
image_ids[:, :, 1] = image_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=Hp, device=device, dtype=torch.float32).unsqueeze(1)
|
||||||
|
image_ids[:, :, 2] = image_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=Wp, device=device, dtype=torch.float32).unsqueeze(0)
|
||||||
|
|
||||||
|
image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
|
||||||
|
|
||||||
|
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
|
||||||
|
del image_ids, text_ids
|
||||||
|
|
||||||
|
sample = self.time_proj(timesteps.to(dtype)).to(self.time_embedding.linear_1.weight.dtype)
|
||||||
|
c = self.time_embedding(sample)
|
||||||
|
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
|
||||||
|
t.unsqueeze(1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||||
|
]
|
||||||
|
|
||||||
|
temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(hidden_states, rotary_pos_emb, temb)
|
||||||
|
|
||||||
|
hidden_states = self.final_norm(hidden_states, c).type_as(hidden_states)
|
||||||
|
|
||||||
|
patches = self.final_linear(hidden_states)[:, :N_img, :]
|
||||||
|
output = (
|
||||||
|
patches.view(B, Hp, Wp, p, p, self.out_channels)
|
||||||
|
.permute(0, 5, 1, 3, 2, 4)
|
||||||
|
.contiguous()
|
||||||
|
.view(B, self.out_channels, H, W)
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
@ -16,7 +16,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
|
|||||||
|
|
||||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||||
assert dim % 2 == 0
|
assert dim % 2 == 0
|
||||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
if not comfy.model_management.supports_fp64(pos.device):
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
else:
|
else:
|
||||||
device = pos.device
|
device = pos.device
|
||||||
|
|||||||
@ -53,6 +53,7 @@ import comfy.ldm.kandinsky5.model
|
|||||||
import comfy.ldm.anima.model
|
import comfy.ldm.anima.model
|
||||||
import comfy.ldm.ace.ace_step15
|
import comfy.ldm.ace.ace_step15
|
||||||
import comfy.ldm.rt_detr.rtdetr_v4
|
import comfy.ldm.rt_detr.rtdetr_v4
|
||||||
|
import comfy.ldm.ernie.model
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@ -1962,3 +1963,14 @@ class Kandinsky5Image(Kandinsky5):
|
|||||||
class RT_DETR_v4(BaseModel):
|
class RT_DETR_v4(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||||
|
|
||||||
|
class ErnieImage(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
return out
|
||||||
|
|||||||
@ -713,6 +713,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "ernie"
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -1732,6 +1732,21 @@ def supports_mxfp8_compute(device=None):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def supports_fp64(device=None):
|
||||||
|
if is_device_mps(device):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if is_intel_xpu():
|
||||||
|
return False
|
||||||
|
|
||||||
|
if is_directml_enabled():
|
||||||
|
return False
|
||||||
|
|
||||||
|
if is_ixuca():
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def extended_fp16_support():
|
def extended_fp16_support():
|
||||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||||
if torch_version_numeric < (2, 7):
|
if torch_version_numeric < (2, 7):
|
||||||
|
|||||||
@ -62,6 +62,7 @@ import comfy.text_encoders.anima
|
|||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
import comfy.text_encoders.longcat_image
|
import comfy.text_encoders.longcat_image
|
||||||
import comfy.text_encoders.qwen35
|
import comfy.text_encoders.qwen35
|
||||||
|
import comfy.text_encoders.ernie
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -1235,6 +1236,7 @@ class TEModel(Enum):
|
|||||||
QWEN35_4B = 25
|
QWEN35_4B = 25
|
||||||
QWEN35_9B = 26
|
QWEN35_9B = 26
|
||||||
QWEN35_27B = 27
|
QWEN35_27B = 27
|
||||||
|
MINISTRAL_3_3B = 28
|
||||||
|
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
@ -1301,6 +1303,8 @@ def detect_te_model(sd):
|
|||||||
return TEModel.MISTRAL3_24B
|
return TEModel.MISTRAL3_24B
|
||||||
else:
|
else:
|
||||||
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
|
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
|
||||||
|
if weight.shape[0] == 3072:
|
||||||
|
return TEModel.MINISTRAL_3_3B
|
||||||
|
|
||||||
return TEModel.LLAMA3_8
|
return TEModel.LLAMA3_8
|
||||||
return None
|
return None
|
||||||
@ -1458,6 +1462,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif te_model == TEModel.QWEN3_06B:
|
elif te_model == TEModel.QWEN3_06B:
|
||||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||||
|
elif te_model == TEModel.MINISTRAL_3_3B:
|
||||||
|
clip_target.clip = comfy.text_encoders.ernie.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.ernie.ErnieTokenizer
|
||||||
|
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||||
else:
|
else:
|
||||||
# clip_l
|
# clip_l
|
||||||
if clip_type == CLIPType.SD3:
|
if clip_type == CLIPType.SD3:
|
||||||
|
|||||||
@ -26,6 +26,7 @@ import comfy.text_encoders.z_image
|
|||||||
import comfy.text_encoders.anima
|
import comfy.text_encoders.anima
|
||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
import comfy.text_encoders.longcat_image
|
import comfy.text_encoders.longcat_image
|
||||||
|
import comfy.text_encoders.ernie
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
@ -1749,6 +1750,37 @@ class RT_DETR_v4(supported_models_base.BASE):
|
|||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
|
|
||||||
|
class ErnieImage(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "ernie",
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"multiplier": 1000.0,
|
||||||
|
"shift": 3.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_usage_factor = 10.0
|
||||||
|
|
||||||
|
unet_extra_config = {}
|
||||||
|
latent_format = latent_formats.Flux2
|
||||||
|
|
||||||
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.ErnieImage(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}ministral3_3b.transformer.".format(pref))
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
|
||||||
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
38
comfy/text_encoders/ernie.py
Normal file
38
comfy/text_encoders/ernie.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from .flux import Mistral3Tokenizer
|
||||||
|
from comfy import sd1_clip
|
||||||
|
import comfy.text_encoders.llama
|
||||||
|
|
||||||
|
class Ministral3_3BTokenizer(Mistral3Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='ministral3_3b', tokenizer_data={}):
|
||||||
|
return super().__init__(embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
class ErnieTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="ministral3_3b", tokenizer=Mistral3Tokenizer)
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||||
|
tokens = super().tokenize_with_weights(text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
class Ministral3_3BModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||||
|
textmodel_json_config = {}
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ministral3_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class ErnieTEModel(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}, name="ministral3_3b", clip_model=Ministral3_3BModel):
|
||||||
|
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
class ErnieTEModel_(ErnieTEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return ErnieTEModel
|
||||||
@ -116,9 +116,9 @@ class MistralTokenizerClass:
|
|||||||
return LlamaTokenizerFast(**kwargs)
|
return LlamaTokenizerFast(**kwargs)
|
||||||
|
|
||||||
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
|
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_data={}):
|
||||||
self.tekken_data = tokenizer_data.get("tekken_model", None)
|
self.tekken_data = tokenizer_data.get("tekken_model", None)
|
||||||
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {"tekken_model": self.tekken_data}
|
return {"tekken_model": self.tekken_data}
|
||||||
|
|||||||
@ -60,6 +60,29 @@ class Mistral3Small24BConfig:
|
|||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Ministral3_3BConfig:
|
||||||
|
vocab_size: int = 131072
|
||||||
|
hidden_size: int = 3072
|
||||||
|
intermediate_size: int = 9216
|
||||||
|
num_hidden_layers: int = 26
|
||||||
|
num_attention_heads: int = 32
|
||||||
|
num_key_value_heads: int = 8
|
||||||
|
max_position_embeddings: int = 262144
|
||||||
|
rms_norm_eps: float = 1e-5
|
||||||
|
rope_theta: float = 1000000.0
|
||||||
|
transformer_type: str = "llama"
|
||||||
|
head_dim = 128
|
||||||
|
rms_norm_add = False
|
||||||
|
mlp_activation = "silu"
|
||||||
|
qkv_bias = False
|
||||||
|
rope_dims = None
|
||||||
|
q_norm = None
|
||||||
|
k_norm = None
|
||||||
|
rope_scale = None
|
||||||
|
final_norm: bool = True
|
||||||
|
lm_head: bool = False
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen25_3BConfig:
|
class Qwen25_3BConfig:
|
||||||
vocab_size: int = 151936
|
vocab_size: int = 151936
|
||||||
@ -946,6 +969,15 @@ class Mistral3Small24B(BaseLlama, torch.nn.Module):
|
|||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
|
class Ministral3_3B(BaseLlama, torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
config = Ministral3_3BConfig(**config_dict)
|
||||||
|
self.num_layers = config.num_hidden_layers
|
||||||
|
|
||||||
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
class Qwen25_3B(BaseLlama, torch.nn.Module):
|
class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -52,6 +52,26 @@ class TaskImageContent(BaseModel):
|
|||||||
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskVideoContentUrl(BaseModel):
|
||||||
|
url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskVideoContent(BaseModel):
|
||||||
|
type: str = Field("video_url")
|
||||||
|
video_url: TaskVideoContentUrl = Field(...)
|
||||||
|
role: str = Field("reference_video")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskAudioContentUrl(BaseModel):
|
||||||
|
url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskAudioContent(BaseModel):
|
||||||
|
type: str = Field("audio_url")
|
||||||
|
audio_url: TaskAudioContentUrl = Field(...)
|
||||||
|
role: str = Field("reference_audio")
|
||||||
|
|
||||||
|
|
||||||
class Text2VideoTaskCreationRequest(BaseModel):
|
class Text2VideoTaskCreationRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||||
@ -64,6 +84,17 @@ class Image2VideoTaskCreationRequest(BaseModel):
|
|||||||
generate_audio: bool | None = Field(...)
|
generate_audio: bool | None = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class Seedance2TaskCreationRequest(BaseModel):
|
||||||
|
model: str = Field(...)
|
||||||
|
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = Field(..., min_length=1)
|
||||||
|
generate_audio: bool | None = Field(None)
|
||||||
|
resolution: str | None = Field(None)
|
||||||
|
ratio: str | None = Field(None)
|
||||||
|
duration: int | None = Field(None, ge=4, le=15)
|
||||||
|
seed: int | None = Field(None, ge=0, le=2147483647)
|
||||||
|
watermark: bool | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class TaskCreationResponse(BaseModel):
|
class TaskCreationResponse(BaseModel):
|
||||||
id: str = Field(...)
|
id: str = Field(...)
|
||||||
|
|
||||||
@ -77,12 +108,27 @@ class TaskStatusResult(BaseModel):
|
|||||||
video_url: str = Field(...)
|
video_url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatusUsage(BaseModel):
|
||||||
|
completion_tokens: int = Field(0)
|
||||||
|
total_tokens: int = Field(0)
|
||||||
|
|
||||||
|
|
||||||
class TaskStatusResponse(BaseModel):
|
class TaskStatusResponse(BaseModel):
|
||||||
id: str = Field(...)
|
id: str = Field(...)
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||||
error: TaskStatusError | None = Field(None)
|
error: TaskStatusError | None = Field(None)
|
||||||
content: TaskStatusResult | None = Field(None)
|
content: TaskStatusResult | None = Field(None)
|
||||||
|
usage: TaskStatusUsage | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||||
|
SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||||
|
("dreamina-seedance-2-0-260128", False): 0.007,
|
||||||
|
("dreamina-seedance-2-0-260128", True): 0.0043,
|
||||||
|
("dreamina-seedance-2-0-fast-260128", False): 0.0056,
|
||||||
|
("dreamina-seedance-2-0-fast-260128", True): 0.0033,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
RECOMMENDED_PRESETS = [
|
RECOMMENDED_PRESETS = [
|
||||||
@ -112,6 +158,12 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
|||||||
("Custom", None, None),
|
("Custom", None, None),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Seedance 2.0 reference video pixel count limits per model.
|
||||||
|
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||||
|
"dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408},
|
||||||
|
"dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408},
|
||||||
|
}
|
||||||
|
|
||||||
# The time in this dictionary are given for 10 seconds duration.
|
# The time in this dictionary are given for 10 seconds duration.
|
||||||
VIDEO_TASKS_EXECUTION_TIME = {
|
VIDEO_TASKS_EXECUTION_TIME = {
|
||||||
"seedance-1-0-lite-t2v-250428": {
|
"seedance-1-0-lite-t2v-250428": {
|
||||||
|
|||||||
@ -8,16 +8,23 @@ from comfy_api.latest import IO, ComfyExtension, Input
|
|||||||
from comfy_api_nodes.apis.bytedance import (
|
from comfy_api_nodes.apis.bytedance import (
|
||||||
RECOMMENDED_PRESETS,
|
RECOMMENDED_PRESETS,
|
||||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||||
|
SEEDANCE2_PRICE_PER_1K_TOKENS,
|
||||||
|
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
|
||||||
VIDEO_TASKS_EXECUTION_TIME,
|
VIDEO_TASKS_EXECUTION_TIME,
|
||||||
Image2VideoTaskCreationRequest,
|
Image2VideoTaskCreationRequest,
|
||||||
ImageTaskCreationResponse,
|
ImageTaskCreationResponse,
|
||||||
|
Seedance2TaskCreationRequest,
|
||||||
Seedream4Options,
|
Seedream4Options,
|
||||||
Seedream4TaskCreationRequest,
|
Seedream4TaskCreationRequest,
|
||||||
|
TaskAudioContent,
|
||||||
|
TaskAudioContentUrl,
|
||||||
TaskCreationResponse,
|
TaskCreationResponse,
|
||||||
TaskImageContent,
|
TaskImageContent,
|
||||||
TaskImageContentUrl,
|
TaskImageContentUrl,
|
||||||
TaskStatusResponse,
|
TaskStatusResponse,
|
||||||
TaskTextContent,
|
TaskTextContent,
|
||||||
|
TaskVideoContent,
|
||||||
|
TaskVideoContentUrl,
|
||||||
Text2ImageTaskCreationRequest,
|
Text2ImageTaskCreationRequest,
|
||||||
Text2VideoTaskCreationRequest,
|
Text2VideoTaskCreationRequest,
|
||||||
)
|
)
|
||||||
@ -29,7 +36,10 @@ from comfy_api_nodes.util import (
|
|||||||
image_tensor_pair_to_batch,
|
image_tensor_pair_to_batch,
|
||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
|
upload_audio_to_comfyapi,
|
||||||
|
upload_image_to_comfyapi,
|
||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
|
upload_video_to_comfyapi,
|
||||||
validate_image_aspect_ratio,
|
validate_image_aspect_ratio,
|
||||||
validate_image_dimensions,
|
validate_image_dimensions,
|
||||||
validate_string,
|
validate_string,
|
||||||
@ -46,12 +56,56 @@ SEEDREAM_MODELS = {
|
|||||||
# Long-running tasks endpoints(e.g., video)
|
# Long-running tasks endpoints(e.g., video)
|
||||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||||
|
BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/contents/generations/tasks" # + /{task_id}
|
||||||
|
|
||||||
|
SEEDANCE_MODELS = {
|
||||||
|
"Seedance 2.0": "dreamina-seedance-2-0-260128",
|
||||||
|
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
|
||||||
|
}
|
||||||
|
|
||||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None:
|
||||||
|
"""Validate reference video pixel count against Seedance 2.0 model limits."""
|
||||||
|
limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
|
||||||
|
if not limits:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
w, h = video.get_dimensions()
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
pixels = w * h
|
||||||
|
min_px = limits.get("min")
|
||||||
|
max_px = limits.get("max")
|
||||||
|
if min_px and pixels < min_px:
|
||||||
|
raise ValueError(
|
||||||
|
f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model."
|
||||||
|
)
|
||||||
|
if max_px and pixels > max_px:
|
||||||
|
raise ValueError(
|
||||||
|
f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. "
|
||||||
|
f"Maximum is {max_px:,}px for this model. Try downscaling the video."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||||
|
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||||
|
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||||
|
if rate is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extractor(response: TaskStatusResponse) -> float | None:
|
||||||
|
if response.usage is None:
|
||||||
|
return None
|
||||||
|
return response.usage.total_tokens * 1.43 * rate / 1_000.0
|
||||||
|
|
||||||
|
return extractor
|
||||||
|
|
||||||
|
|
||||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||||
if response.error:
|
if response.error:
|
||||||
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
||||||
@ -335,8 +389,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
mp_provided = out_num_pixels / 1_000_000.0
|
mp_provided = out_num_pixels / 1_000_000.0
|
||||||
if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400:
|
if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Minimum image resolution for the selected model is 3.68MP, "
|
f"Minimum image resolution for the selected model is 3.68MP, " f"but {mp_provided:.2f}MP provided."
|
||||||
f"but {mp_provided:.2f}MP provided."
|
|
||||||
)
|
)
|
||||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -952,33 +1005,6 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def process_video_task(
|
|
||||||
cls: type[IO.ComfyNode],
|
|
||||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
|
||||||
estimated_duration: int | None,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
if payload.model in DEPRECATED_MODELS:
|
|
||||||
logger.warning(
|
|
||||||
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
|
||||||
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
|
||||||
payload.model,
|
|
||||||
)
|
|
||||||
initial_response = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
|
||||||
data=payload,
|
|
||||||
response_model=TaskCreationResponse,
|
|
||||||
)
|
|
||||||
response = await poll_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
|
||||||
status_extractor=lambda r: r.status,
|
|
||||||
estimated_duration=estimated_duration,
|
|
||||||
response_model=TaskStatusResponse,
|
|
||||||
)
|
|
||||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
|
||||||
|
|
||||||
|
|
||||||
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||||
for i in text_params:
|
for i in text_params:
|
||||||
if f"--{i} " in prompt:
|
if f"--{i} " in prompt:
|
||||||
@ -1040,6 +1066,530 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _seedance2_text_inputs():
|
||||||
|
return [
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text prompt for video generation.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=["480p", "720p"],
|
||||||
|
tooltip="Resolution of the output video.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"ratio",
|
||||||
|
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
||||||
|
tooltip="Aspect ratio of the output video.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=7,
|
||||||
|
min=4,
|
||||||
|
max=15,
|
||||||
|
step=1,
|
||||||
|
tooltip="Duration of the output video in seconds (4-15).",
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"generate_audio",
|
||||||
|
default=True,
|
||||||
|
tooltip="Enable audio generation for the output video.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="ByteDance2TextToVideoNode",
|
||||||
|
display_name="ByteDance Seedance 2.0 Text to Video",
|
||||||
|
category="api node/video/ByteDance",
|
||||||
|
description="Generate video using Seedance 2.0 models based on a text prompt.",
|
||||||
|
inputs=[
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
|
||||||
|
],
|
||||||
|
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed controls whether the node should re-run; "
|
||||||
|
"results are non-deterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"watermark",
|
||||||
|
default=False,
|
||||||
|
tooltip="Whether to add a watermark to the video.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$rate480 := 10044;
|
||||||
|
$rate720 := 21600;
|
||||||
|
$m := widgets.model;
|
||||||
|
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
$rate := $res = "720p" ? $rate720 : $rate480;
|
||||||
|
$cost := $dur * $rate * $pricePer1K / 1000;
|
||||||
|
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
watermark: bool,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||||
|
model_id = SEEDANCE_MODELS[model["model"]]
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||||
|
data=Seedance2TaskCreationRequest(
|
||||||
|
model=model_id,
|
||||||
|
content=[TaskTextContent(text=model["prompt"])],
|
||||||
|
generate_audio=model["generate_audio"],
|
||||||
|
resolution=model["resolution"],
|
||||||
|
ratio=model["ratio"],
|
||||||
|
duration=model["duration"],
|
||||||
|
seed=seed,
|
||||||
|
watermark=watermark,
|
||||||
|
),
|
||||||
|
response_model=TaskCreationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||||
|
poll_interval=9,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="ByteDance2FirstLastFrameNode",
|
||||||
|
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
|
||||||
|
category="api node/video/ByteDance",
|
||||||
|
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
|
||||||
|
inputs=[
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
|
||||||
|
],
|
||||||
|
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"first_frame",
|
||||||
|
tooltip="First frame image for the video.",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"last_frame",
|
||||||
|
tooltip="Last frame image for the video.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed controls whether the node should re-run; "
|
||||||
|
"results are non-deterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"watermark",
|
||||||
|
default=False,
|
||||||
|
tooltip="Whether to add a watermark to the video.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$rate480 := 10044;
|
||||||
|
$rate720 := 21600;
|
||||||
|
$m := widgets.model;
|
||||||
|
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
$rate := $res = "720p" ? $rate720 : $rate480;
|
||||||
|
$cost := $dur * $rate * $pricePer1K / 1000;
|
||||||
|
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: dict,
|
||||||
|
first_frame: Input.Image,
|
||||||
|
seed: int,
|
||||||
|
watermark: bool,
|
||||||
|
last_frame: Input.Image | None = None,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||||
|
model_id = SEEDANCE_MODELS[model["model"]]
|
||||||
|
|
||||||
|
content: list[TaskTextContent | TaskImageContent] = [
|
||||||
|
TaskTextContent(text=model["prompt"]),
|
||||||
|
TaskImageContent(
|
||||||
|
image_url=TaskImageContentUrl(
|
||||||
|
url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
|
||||||
|
),
|
||||||
|
role="first_frame",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
if last_frame is not None:
|
||||||
|
content.append(
|
||||||
|
TaskImageContent(
|
||||||
|
image_url=TaskImageContentUrl(
|
||||||
|
url=await upload_image_to_comfyapi(cls, last_frame, wait_label="Uploading last frame.")
|
||||||
|
),
|
||||||
|
role="last_frame",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||||
|
data=Seedance2TaskCreationRequest(
|
||||||
|
model=model_id,
|
||||||
|
content=content,
|
||||||
|
generate_audio=model["generate_audio"],
|
||||||
|
resolution=model["resolution"],
|
||||||
|
ratio=model["ratio"],
|
||||||
|
duration=model["duration"],
|
||||||
|
seed=seed,
|
||||||
|
watermark=watermark,
|
||||||
|
),
|
||||||
|
response_model=TaskCreationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||||
|
poll_interval=9,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
def _seedance2_reference_inputs():
|
||||||
|
return [
|
||||||
|
*_seedance2_text_inputs(),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_images",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Image.Input("reference_image"),
|
||||||
|
names=[
|
||||||
|
"image_1",
|
||||||
|
"image_2",
|
||||||
|
"image_3",
|
||||||
|
"image_4",
|
||||||
|
"image_5",
|
||||||
|
"image_6",
|
||||||
|
"image_7",
|
||||||
|
"image_8",
|
||||||
|
"image_9",
|
||||||
|
],
|
||||||
|
min=0,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_videos",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Video.Input("reference_video"),
|
||||||
|
names=["video_1", "video_2", "video_3"],
|
||||||
|
min=0,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_audios",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Audio.Input("reference_audio"),
|
||||||
|
names=["audio_1", "audio_2", "audio_3"],
|
||||||
|
min=0,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="ByteDance2ReferenceNode",
|
||||||
|
display_name="ByteDance Seedance 2.0 Reference to Video",
|
||||||
|
category="api node/video/ByteDance",
|
||||||
|
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
|
||||||
|
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
|
||||||
|
inputs=[
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs()),
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs()),
|
||||||
|
],
|
||||||
|
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed controls whether the node should re-run; "
|
||||||
|
"results are non-deterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"watermark",
|
||||||
|
default=False,
|
||||||
|
tooltip="Whether to add a watermark to the video.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(
|
||||||
|
widgets=["model", "model.resolution", "model.duration"],
|
||||||
|
input_groups=["model.reference_videos"],
|
||||||
|
),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$rate480 := 10044;
|
||||||
|
$rate720 := 21600;
|
||||||
|
$m := widgets.model;
|
||||||
|
$hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
|
||||||
|
$noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||||
|
$videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
$rate := $res = "720p" ? $rate720 : $rate480;
|
||||||
|
$noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
|
||||||
|
$minVideoFactor := $ceil($dur * 5 / 3);
|
||||||
|
$minVideoCost := $minVideoFactor * $rate * $videoPricePer1K / 1000;
|
||||||
|
$maxVideoCost := (15 + $dur) * $rate * $videoPricePer1K / 1000;
|
||||||
|
$hasVideo
|
||||||
|
? {
|
||||||
|
"type": "range_usd",
|
||||||
|
"min_usd": $minVideoCost,
|
||||||
|
"max_usd": $maxVideoCost,
|
||||||
|
"format": {"approximate": true}
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
"type": "usd",
|
||||||
|
"usd": $noVideoCost,
|
||||||
|
"format": {"approximate": true}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
watermark: bool,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||||
|
|
||||||
|
reference_images = model.get("reference_images", {})
|
||||||
|
reference_videos = model.get("reference_videos", {})
|
||||||
|
reference_audios = model.get("reference_audios", {})
|
||||||
|
|
||||||
|
if not reference_images and not reference_videos:
|
||||||
|
raise ValueError("At least one reference image or video is required.")
|
||||||
|
|
||||||
|
model_id = SEEDANCE_MODELS[model["model"]]
|
||||||
|
has_video_input = len(reference_videos) > 0
|
||||||
|
total_video_duration = 0.0
|
||||||
|
for i, key in enumerate(reference_videos, 1):
|
||||||
|
video = reference_videos[key]
|
||||||
|
_validate_ref_video_pixels(video, model_id, i)
|
||||||
|
try:
|
||||||
|
dur = video.get_duration()
|
||||||
|
if dur < 1.8:
|
||||||
|
raise ValueError(f"Reference video {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.")
|
||||||
|
total_video_duration += dur
|
||||||
|
except ValueError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if total_video_duration > 15.1:
|
||||||
|
raise ValueError(f"Total reference video duration is {total_video_duration:.1f}s. Maximum is 15.1 seconds.")
|
||||||
|
|
||||||
|
total_audio_duration = 0.0
|
||||||
|
for i, key in enumerate(reference_audios, 1):
|
||||||
|
audio = reference_audios[key]
|
||||||
|
dur = int(audio["waveform"].shape[-1]) / int(audio["sample_rate"])
|
||||||
|
if dur < 1.8:
|
||||||
|
raise ValueError(f"Reference audio {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.")
|
||||||
|
total_audio_duration += dur
|
||||||
|
if total_audio_duration > 15.1:
|
||||||
|
raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.")
|
||||||
|
|
||||||
|
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [
|
||||||
|
TaskTextContent(text=model["prompt"]),
|
||||||
|
]
|
||||||
|
for i, key in enumerate(reference_images, 1):
|
||||||
|
content.append(
|
||||||
|
TaskImageContent(
|
||||||
|
image_url=TaskImageContentUrl(
|
||||||
|
url=await upload_image_to_comfyapi(
|
||||||
|
cls,
|
||||||
|
image=reference_images[key],
|
||||||
|
wait_label=f"Uploading image {i}",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
role="reference_image",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for i, key in enumerate(reference_videos, 1):
|
||||||
|
content.append(
|
||||||
|
TaskVideoContent(
|
||||||
|
video_url=TaskVideoContentUrl(
|
||||||
|
url=await upload_video_to_comfyapi(
|
||||||
|
cls,
|
||||||
|
reference_videos[key],
|
||||||
|
wait_label=f"Uploading video {i}",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for key in reference_audios:
|
||||||
|
content.append(
|
||||||
|
TaskAudioContent(
|
||||||
|
audio_url=TaskAudioContentUrl(
|
||||||
|
url=await upload_audio_to_comfyapi(
|
||||||
|
cls,
|
||||||
|
reference_audios[key],
|
||||||
|
container_format="mp3",
|
||||||
|
codec_name="libmp3lame",
|
||||||
|
mime_type="audio/mpeg",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||||
|
data=Seedance2TaskCreationRequest(
|
||||||
|
model=model_id,
|
||||||
|
content=content,
|
||||||
|
generate_audio=model["generate_audio"],
|
||||||
|
resolution=model["resolution"],
|
||||||
|
ratio=model["ratio"],
|
||||||
|
duration=model["duration"],
|
||||||
|
seed=seed,
|
||||||
|
watermark=watermark,
|
||||||
|
),
|
||||||
|
response_model=TaskCreationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
|
||||||
|
poll_interval=9,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
async def process_video_task(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||||
|
estimated_duration: int | None,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
if payload.model in DEPRECATED_MODELS:
|
||||||
|
logger.warning(
|
||||||
|
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||||
|
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||||
|
payload.model,
|
||||||
|
)
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||||
|
data=payload,
|
||||||
|
response_model=TaskCreationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
estimated_duration=estimated_duration,
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
class ByteDanceExtension(ComfyExtension):
|
class ByteDanceExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -1050,6 +1600,9 @@ class ByteDanceExtension(ComfyExtension):
|
|||||||
ByteDanceImageToVideoNode,
|
ByteDanceImageToVideoNode,
|
||||||
ByteDanceFirstLastFrameNode,
|
ByteDanceFirstLastFrameNode,
|
||||||
ByteDanceImageReferenceNode,
|
ByteDanceImageReferenceNode,
|
||||||
|
ByteDance2TextToVideoNode,
|
||||||
|
ByteDance2FirstLastFrameNode,
|
||||||
|
ByteDance2ReferenceNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.18.1"
|
__version__ = "0.19.0"
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.18.1"
|
version = "0.19.0"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.42.10
|
comfyui-frontend-package==1.42.10
|
||||||
comfyui-workflow-templates==0.9.45
|
comfyui-workflow-templates==0.9.47
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user