mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-23 17:02:38 +08:00
Merge remote-tracking branch 'origin/worksplit-multigpu' into worksplit-multigpu-wip
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
This commit is contained in:
commit
bae191c294
2
.ci/windows_intel_base_files/run_intel_gpu.bat
Executable file
2
.ci/windows_intel_base_files/run_intel_gpu.bat
Executable file
@ -0,0 +1,2 @@
|
|||||||
|
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||||
|
pause
|
||||||
@ -139,9 +139,9 @@ Example:
|
|||||||
"_quantization_metadata": {
|
"_quantization_metadata": {
|
||||||
"format_version": "1.0",
|
"format_version": "1.0",
|
||||||
"layers": {
|
"layers": {
|
||||||
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
|
"model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
|
||||||
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
|
"model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
|
||||||
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
|
"model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -165,4 +165,4 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s
|
|||||||
3. **Compute scales**: Derive `input_scale` from collected statistics
|
3. **Compute scales**: Derive `input_scale` from collected statistics
|
||||||
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
|
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
|
||||||
|
|
||||||
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
||||||
|
|||||||
@ -195,7 +195,9 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
|
|||||||
|
|
||||||
#### Alternative Downloads:
|
#### Alternative Downloads:
|
||||||
|
|
||||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||||
|
|
||||||
|
[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
|
||||||
|
|
||||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||||
|
|
||||||
|
|||||||
@ -67,7 +67,7 @@ class InternalRoutes:
|
|||||||
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
|
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
|
||||||
key=lambda entry: -entry.stat().st_mtime
|
key=lambda entry: -entry.stat().st_mtime
|
||||||
)
|
)
|
||||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
return web.json_response([f"{entry.name} [{directory_type}]" for entry in sorted_files], status=200)
|
||||||
|
|
||||||
|
|
||||||
def get_app(self):
|
def get_app(self):
|
||||||
|
|||||||
@ -182,7 +182,7 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"widgets_values": [
|
"widgets_values": [
|
||||||
50
|
0
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@ -316,7 +316,7 @@
|
|||||||
"step": 1
|
"step": 1
|
||||||
},
|
},
|
||||||
"widgets_values": [
|
"widgets_values": [
|
||||||
30
|
0
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
301
comfy/ldm/ernie/model.py
Normal file
301
comfy/ldm/ernie/model.py
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Build the cos/sin rotary table for a single positional axis.

    Args:
        pos: Positions along one axis.
        dim: Number of channels assigned to this axis; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor on ``pos.device`` whose leading axis of size 2
        holds ``cos`` (index 0) and ``sin`` (index 1) of ``pos * omega``,
        with ``dim // 2`` frequencies on the last axis.
    """
    assert dim % 2 == 0

    # The frequency table is computed in float64; devices that cannot do
    # fp64 compute it on the CPU and the result is moved back afterwards.
    if comfy.model_management.supports_fp64(pos.device):
        compute_device = pos.device
    else:
        compute_device = torch.device("cpu")

    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=compute_device) / dim
    omega = 1.0 / (theta**exponents)
    angles = torch.einsum("...n,d->...nd", pos.to(compute_device), omega)
    table = torch.stack([torch.cos(angles), torch.sin(angles)], dim=0)
    return table.to(dtype=torch.float32, device=pos.device)
|
||||||
|
|
||||||
|
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate the leading channels of ``x_in`` using a stacked cos/sin table.

    Args:
        x_in: Tensor whose last axis holds the channels to rotate.
        freqs_cis: Table with ``freqs_cis[0]`` = cos and ``freqs_cis[1]`` =
            sin; its last-axis width decides how many channels are rotated.

    Returns:
        A tensor shaped like ``x_in``: the first ``freqs_cis.shape[-1]``
        channels are rotated, any remaining channels are passed through.
    """
    cos_table, sin_table = freqs_cis[0], freqs_cis[1]
    width = freqs_cis.shape[-1]

    rotating = x_in[..., :width]
    untouched = x_in[..., width:]

    # "Rotate half": the second half is negated and swapped with the first.
    first_half, second_half = rotating.chunk(2, dim=-1)
    swapped = torch.cat((-second_half, first_half), dim=-1)

    mixed = rotating * cos_table + swapped * sin_table
    return torch.cat((mixed, untouched), dim=-1)
|
||||||
|
|
||||||
|
class ErnieImageEmbedND3(nn.Module):
    """Rotary position embedding over three id axes.

    Each of the three id channels gets its own ``rope`` table sized by the
    matching entry of ``axes_dim``; the tables are concatenated on the
    frequency axis and every frequency is repeated twice back to back.
    """

    def __init__(self, dim: int, theta: int, axes_dim: tuple):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = list(axes_dim)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Return the stacked cos/sin table for ``ids``.

        Args:
            ids: Position ids whose last axis has (at least) three channels.
                Presumably shaped [B, S, 3] -- confirm against the caller.

        Returns:
            Tensor with a leading cos/sin axis of size 2 (from ``rope``), a
            singleton axis inserted at position 3, and a last axis of width
            ``sum(axes_dim[:3])``.
        """
        per_axis = []
        for axis in range(3):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))

        # [2, ..., sum(axes_dim) // 2] -> singleton axis inserted at index 3
        table = torch.cat(per_axis, dim=-1).unsqueeze(3)

        # Duplicate every frequency in place (f0, f0, f1, f1, ...), doubling
        # the last axis while keeping the leading cos/sin axis intact.
        doubled = torch.stack([table, table], dim=-1)
        return doubled.reshape(*table.shape[:-1], -1)
|
||||||
|
|
||||||
|
class ErnieImagePatchEmbedDynamic(nn.Module):
    """Turn an image-like tensor into a token sequence with a strided conv."""

    def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
        super().__init__()
        self.patch_size = patch_size
        # kernel == stride == patch_size, so each patch becomes one token.
        self.proj = operations.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
            device=device,
            dtype=dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and flatten the spatial grid.

        Returns:
            Contiguous tensor of shape [batch, height * width, embed_dim],
            where height/width are the conv output's spatial sizes.
        """
        projected = self.proj(x)
        tokens = projected.flatten(2).transpose(1, 2)
        return tokens.contiguous()
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
    """Sinusoidal timestep features with a fixed base of 10000.

    Args:
        num_channels: Width of the produced embedding.
        flip_sin_to_cos: When true the cosine half comes first; otherwise
            the sine half comes first.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool = False):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Embed a 1-D tensor of timesteps.

        Returns:
            Float32 tensor of shape [len(timesteps), num_channels].
        """
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) / half_dim
        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]
        if self.flip_sin_to_cos:
            emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
        else:
            emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        # sin/cos halves only cover 2 * (num_channels // 2) columns; for an
        # odd width add one zero column so the output matches num_channels
        # instead of silently coming out one channel short.
        if self.num_channels % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to timestep features."""

    def __init__(self, in_channels: int, time_embed_dim: int, operations, device=None, dtype=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        self.linear_1 = operations.Linear(in_channels, time_embed_dim, bias=True, **factory)
        self.act = nn.SiLU()
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, bias=True, **factory)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """Map ``sample`` (last axis ``in_channels``) to ``time_embed_dim``."""
        hidden = self.act(self.linear_1(sample))
        return self.linear_2(hidden)
|
||||||
|
|
||||||
|
class ErnieImageAttention(nn.Module):
    """Self-attention with per-head RMS-normalised queries and keys.

    Queries and keys are normalised over ``dim_head`` and optionally rotated
    with a rotary table before attention; values are used as projected.
    """

    def __init__(self, query_dim: int, heads: int, dim_head: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = heads * dim_head

        factory = {"device": device, "dtype": dtype}

        self.to_q = operations.Linear(query_dim, self.inner_dim, bias=False, **factory)
        self.to_k = operations.Linear(query_dim, self.inner_dim, bias=False, **factory)
        self.to_v = operations.Linear(query_dim, self.inner_dim, bias=False, **factory)

        self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, **factory)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, **factory)

        # Single-entry ModuleList: the output projection lives at index 0.
        self.to_out = nn.ModuleList([operations.Linear(self.inner_dim, query_dim, bias=False, **factory)])

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_rotary_emb: torch.Tensor = None) -> torch.Tensor:
        """Attend over the sequence axis of ``x`` ([B, S, query_dim]).

        Args:
            x: Input tokens.
            attention_mask: Forwarded to ``optimized_attention`` as ``mask``.
            image_rotary_emb: Optional cos/sin table passed to
                ``apply_rotary_emb`` for both queries and keys.
        """
        B, S, _ = x.shape
        per_head = (B, S, self.heads, self.head_dim)

        # Normalisation and rotation act per head, so q/k are split first.
        query = self.norm_q(self.to_q(x).view(per_head))
        key = self.norm_k(self.to_k(x).view(per_head))
        value = self.to_v(x)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # optimized_attention receives flat [B, S, heads * head_dim] tensors.
        attended = optimized_attention(
            query.reshape(B, S, -1),
            key.reshape(B, S, -1),
            value,
            self.heads,
            mask=attention_mask,
        )
        return self.to_out[0](attended)
|
||||||
|
|
||||||
|
class ErnieImageFeedForward(nn.Module):
    """Gated feed-forward: ``linear_fc2(up_proj(x) * gelu(gate_proj(x)))``.

    All three projections are bias-free.
    """

    def __init__(self, hidden_size: int, ffn_hidden_size: int, operations, device=None, dtype=None):
        super().__init__()
        factory = {"bias": False, "device": device, "dtype": dtype}
        self.gate_proj = operations.Linear(hidden_size, ffn_hidden_size, **factory)
        self.up_proj = operations.Linear(hidden_size, ffn_hidden_size, **factory)
        self.linear_fc2 = operations.Linear(ffn_hidden_size, hidden_size, **factory)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP over the last axis of ``x``."""
        lifted = self.up_proj(x)
        gate = F.gelu(self.gate_proj(x))
        return self.linear_fc2(lifted * gate)
|
||||||
|
|
||||||
|
class ErnieImageSharedAdaLNBlock(nn.Module):
    """Transformer block whose modulation tensors are supplied by the caller.

    The block owns no modulation projection of its own: ``forward`` receives
    six tensors (shift/scale/gate for attention and for the MLP) via ``temb``.
    """

    def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}

        self.adaLN_sa_ln = operations.RMSNorm(hidden_size, eps=eps, **factory)
        self.self_attention = ErnieImageAttention(
            query_dim=hidden_size,
            dim_head=hidden_size // num_heads,
            heads=num_heads,
            eps=eps,
            operations=operations,
            **factory,
        )
        self.adaLN_mlp_ln = operations.RMSNorm(hidden_size, eps=eps, **factory)
        self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size, operations=operations, **factory)

    def forward(self, x, rotary_pos_emb, temb, attention_mask=None):
        """Run the attention and MLP sub-layers with gated residuals.

        Args:
            x: Input tokens.
            rotary_pos_emb: Rotary table forwarded to the attention layer.
            temb: Sequence of six tensors, in order: shift_msa, scale_msa,
                gate_msa, shift_mlp, scale_mlp, gate_mlp.
            attention_mask: Optional mask forwarded to the attention layer.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb

        # Attention sub-layer: norm -> modulate -> attend -> gated residual.
        attn_in = self.adaLN_sa_ln(x) * (1 + scale_msa) + shift_msa
        attn_out = self.self_attention(attn_in, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
        x = x + gate_msa * attn_out

        # MLP sub-layer follows the same pattern with its own modulation.
        mlp_in = self.adaLN_mlp_ln(x) * (1 + scale_mlp) + shift_mlp
        return x + gate_mlp * self.mlp(mlp_in)
|
||||||
|
|
||||||
|
class ErnieImageAdaLNContinuous(nn.Module):
    """Affine-free LayerNorm followed by conditioning-driven scale and shift."""

    def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        self.norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=eps, **factory)
        # One projection yields both halves: first scale, then shift.
        self.linear = operations.Linear(hidden_size, hidden_size * 2, **factory)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        """Return ``shift + norm(x) * (1 + scale)``.

        ``scale`` and ``shift`` come from ``conditioning`` and get an extra
        axis at position 1 so they broadcast over that axis of ``x``.
        """
        modulation = self.linear(conditioning)
        scale, shift = modulation.chunk(2, dim=-1)
        normed = self.norm(x)
        return torch.addcmul(shift.unsqueeze(1), normed, 1 + scale.unsqueeze(1))
|
||||||
|
|
||||||
|
class ErnieImageModel(nn.Module):
    """Image/text transformer with modulation shared across all layers.

    Image patches and text tokens are concatenated (image first) into one
    sequence, processed by ``num_layers`` shared-AdaLN blocks, and the image
    part is projected back to a [B, out_channels, H, W] tensor.

    ``qk_layernorm`` and any extra ``**kwargs`` are accepted but not used by
    this constructor.
    """

    def __init__(
        self,
        hidden_size: int = 4096,
        num_attention_heads: int = 32,
        num_layers: int = 36,
        ffn_hidden_size: int = 12288,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        text_in_dim: int = 3072,
        rope_theta: int = 256,
        rope_axes_dim: tuple = (32, 48, 48),
        eps: float = 1e-6,
        qk_layernorm: bool = True,
        device=None,
        dtype=None,
        operations=None,
        **kwargs
    ):
        super().__init__()
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.patch_size = patch_size
        self.out_channels = out_channels

        factory = {"device": device, "dtype": dtype}

        self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size, operations, **factory)

        # Text features only need a projection when their width differs.
        if text_in_dim != hidden_size:
            self.text_proj = operations.Linear(text_in_dim, hidden_size, bias=False, **factory)
        else:
            self.text_proj = None

        self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False)
        self.time_embedding = TimestepEmbedding(hidden_size, hidden_size, operations, **factory)

        self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)

        # One modulation head produces the six tensors used by every layer.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 6 * hidden_size, **factory),
        )

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, operations, **factory)
            )
        self.layers = nn.ModuleList(blocks)

        self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps, operations, **factory)
        self.final_linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, **factory)

    def forward(self, x, timesteps, context, **kwargs):
        """Run the model on a [B, C, H, W] input.

        Args:
            x: Input tensor; its device and dtype drive intermediate casts.
            timesteps: 1-D timestep tensor fed to the sinusoidal embedding.
            context: Text features [B, T, text_in_dim]; projected to
                ``hidden_size`` when a projection exists and it is non-empty.
            **kwargs: Only ``transformer_options["rope_options"]`` is read
                (keys ``scale_x``, ``scale_y``, ``shift_t``, ``shift_x``,
                ``shift_y``).

        Returns:
            Tensor of shape [B, out_channels, H, W].
        """
        device, dtype = x.device, x.dtype
        B, C, H, W = x.shape
        p = self.patch_size
        Hp, Wp = H // p, W // p
        N_img = Hp * Wp

        img_tokens = self.x_embedder(x)

        text_tokens = context
        if self.text_proj is not None and text_tokens.numel() > 0:
            text_tokens = self.text_proj(text_tokens)
        Tmax = text_tokens.shape[1]

        # Image tokens come first; the output slice below relies on this.
        hidden_states = torch.cat([img_tokens, text_tokens], dim=1)

        # Text ids: channel 0 counts 0..Tmax-1, channels 1 and 2 stay zero.
        text_ids = torch.zeros((B, Tmax, 3), device=device, dtype=torch.float32)
        text_ids[:, :, 0] = torch.linspace(0, Tmax - 1, steps=Tmax, device=x.device, dtype=torch.float32)
        index = float(Tmax)

        transformer_options = kwargs.get("transformer_options", {})
        rope_options = transformer_options.get("rope_options", None)

        h_len, w_len = float(Hp), float(Wp)
        h_offset, w_offset = 0.0, 0.0

        if rope_options is not None:
            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
            index += rope_options.get("shift_t", 0.0)
            h_offset += rope_options.get("shift_y", 0.0)
            w_offset += rope_options.get("shift_x", 0.0)

        # Image ids start at zero, so in-place adds fill each channel:
        # channel 0 = constant index, channel 1 = row coordinate,
        # channel 2 = column coordinate.
        row_coords = torch.linspace(h_offset, h_len - 1 + h_offset, steps=Hp, device=device, dtype=torch.float32)
        col_coords = torch.linspace(w_offset, w_len - 1 + w_offset, steps=Wp, device=device, dtype=torch.float32)
        image_ids = torch.zeros((Hp, Wp, 3), device=device, dtype=torch.float32)
        image_ids[:, :, 0] += index
        image_ids[:, :, 1] += row_coords.unsqueeze(1)
        image_ids[:, :, 2] += col_coords.unsqueeze(0)

        image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)

        rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
        del image_ids, text_ids

        sample = self.time_proj(timesteps).to(dtype)
        c = self.time_embedding(sample)

        # Order: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
        # gate_mlp -- each gets an extra axis at position 1.
        temb = [t.unsqueeze(1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)]

        for layer in self.layers:
            hidden_states = layer(hidden_states, rotary_pos_emb, temb)

        hidden_states = self.final_norm(hidden_states, c).type_as(hidden_states)

        # Keep only the image tokens, then fold patches back into pixels.
        patches = self.final_linear(hidden_states)[:, :N_img, :]
        grid = patches.view(B, Hp, Wp, p, p, self.out_channels)
        output = grid.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W)

        return output
|
||||||
@ -16,7 +16,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
|
|||||||
|
|
||||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||||
assert dim % 2 == 0
|
assert dim % 2 == 0
|
||||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
if not comfy.model_management.supports_fp64(pos.device):
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
else:
|
else:
|
||||||
device = pos.device
|
device = pos.device
|
||||||
|
|||||||
@ -34,6 +34,16 @@ class TimestepBlock(nn.Module):
|
|||||||
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
||||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
||||||
for layer in ts:
|
for layer in ts:
|
||||||
|
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||||
|
found_patched = False
|
||||||
|
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||||
|
if isinstance(layer, class_type):
|
||||||
|
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||||
|
found_patched = True
|
||||||
|
break
|
||||||
|
if found_patched:
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(layer, VideoResBlock):
|
if isinstance(layer, VideoResBlock):
|
||||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||||
elif isinstance(layer, TimestepBlock):
|
elif isinstance(layer, TimestepBlock):
|
||||||
@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
|||||||
elif isinstance(layer, Upsample):
|
elif isinstance(layer, Upsample):
|
||||||
x = layer(x, output_shape=output_shape)
|
x = layer(x, output_shape=output_shape)
|
||||||
else:
|
else:
|
||||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
|
||||||
found_patched = False
|
|
||||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
|
||||||
if isinstance(layer, class_type):
|
|
||||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
|
||||||
found_patched = True
|
|
||||||
break
|
|
||||||
if found_patched:
|
|
||||||
continue
|
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -894,6 +895,12 @@ class UNetModel(nn.Module):
|
|||||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||||
h = apply_control(h, control, 'middle')
|
h = apply_control(h, control, 'middle')
|
||||||
|
|
||||||
|
if "middle_block_after_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["middle_block_after_patch"]
|
||||||
|
for p in patch:
|
||||||
|
out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
|
||||||
|
"timesteps": timesteps, "transformer_options": transformer_options})
|
||||||
|
h = out["h"]
|
||||||
|
|
||||||
for id, module in enumerate(self.output_blocks):
|
for id, module in enumerate(self.output_blocks):
|
||||||
transformer_options["block"] = ("output", id)
|
transformer_options["block"] = ("output", id)
|
||||||
@ -905,8 +912,9 @@ class UNetModel(nn.Module):
|
|||||||
for p in patch:
|
for p in patch:
|
||||||
h, hsp = p(h, hsp, transformer_options)
|
h, hsp = p(h, hsp, transformer_options)
|
||||||
|
|
||||||
h = th.cat([h, hsp], dim=1)
|
if hsp is not None:
|
||||||
del hsp
|
h = th.cat([h, hsp], dim=1)
|
||||||
|
del hsp
|
||||||
if len(hs) > 0:
|
if len(hs) > 0:
|
||||||
output_shape = hs[-1].shape
|
output_shape = hs[-1].shape
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -90,7 +90,7 @@ class HeatmapHead(torch.nn.Module):
|
|||||||
origin_max = np.max(hm[k])
|
origin_max = np.max(hm[k])
|
||||||
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
|
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
|
||||||
dr[border:-border, border:-border] = hm[k].copy()
|
dr[border:-border, border:-border] = hm[k].copy()
|
||||||
dr = gaussian_filter(dr, sigma=2.0)
|
dr = gaussian_filter(dr, sigma=2.0, truncate=2.5)
|
||||||
hm[k] = dr[border:-border, border:-border].copy()
|
hm[k] = dr[border:-border, border:-border].copy()
|
||||||
cur_max = np.max(hm[k])
|
cur_max = np.max(hm[k])
|
||||||
if cur_max > 0:
|
if cur_max > 0:
|
||||||
|
|||||||
0
comfy/ldm/supir/__init__.py
Normal file
0
comfy/ldm/supir/__init__.py
Normal file
226
comfy/ldm/supir/supir_modules.py
Normal file
226
comfy/ldm/supir/supir_modules.py
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||||
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroSFT(nn.Module):
    """Modulate a feature map ``h`` with scale/shift predicted from ``c``.

    Args:
        label_nc: Channels of the control input ``c``.
        norm_nc: Channels of ``h``.
        concat_channels: Channels of an optional skip tensor ``h_ori`` that is
            concatenated before normalisation; 0 disables pre-concatenation.
    """

    def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None):
        super().__init__()

        kernel = 3
        pad = kernel // 2
        hidden_width = 128
        out_width = norm_nc + concat_channels
        factory = {"dtype": dtype, "device": device}

        self.param_free_norm = operations.GroupNorm(32, out_width, **factory)

        self.mlp_shared = nn.Sequential(
            operations.Conv2d(label_nc, hidden_width, kernel_size=kernel, padding=pad, **factory),
            nn.SiLU(),
        )
        self.zero_mul = operations.Conv2d(hidden_width, out_width, kernel_size=kernel, padding=pad, **factory)
        self.zero_add = operations.Conv2d(hidden_width, out_width, kernel_size=kernel, padding=pad, **factory)

        self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, **factory)
        self.pre_concat = bool(concat_channels != 0)

    def forward(self, c, h, h_ori=None, control_scale=1):
        """Blend the unmodified features with their modulated version.

        Args:
            c: Control features used to predict gamma/beta.
            h: Features to modulate.
            h_ori: Optional skip features concatenated in front of ``h``
                (before the norm when ``pre_concat``, after it otherwise).
            control_scale: Interpolation weight; 0 returns the unmodulated
                features, 1 the fully modulated ones.
        """
        has_skip = h_ori is not None
        concat_before = has_skip and self.pre_concat

        # Baseline for the final interpolation (no control applied).
        h_raw = torch.cat([h_ori, h], dim=1) if concat_before else h

        h = h + self.zero_conv(c)
        if concat_before:
            h = torch.cat([h_ori, h], dim=1)

        actv = self.mlp_shared(c)
        gamma = self.zero_mul(actv)
        beta = self.zero_add(actv)

        # normed * (1 + gamma) + beta, written as a fused multiply-add.
        h = self.param_free_norm(h)
        h = torch.addcmul(h + beta, h, gamma)

        if has_skip and not self.pre_concat:
            h = torch.cat([h_ori, h], dim=1)
        return torch.lerp(h_raw, h, control_scale)
|
||||||
|
|
||||||
|
|
||||||
|
class _CrossAttnInner(nn.Module):
    """Inner cross-attention module matching the state_dict layout of the original CrossAttention."""

    def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = heads
        width = dim_head * heads
        factory = {"dtype": dtype, "device": device}

        self.to_q = operations.Linear(query_dim, width, bias=False, **factory)
        self.to_k = operations.Linear(context_dim, width, bias=False, **factory)
        self.to_v = operations.Linear(context_dim, width, bias=False, **factory)
        # Kept as a one-element Sequential so the projection is "to_out.0".
        self.to_out = nn.Sequential(
            operations.Linear(width, query_dim, **factory),
        )

    def forward(self, x, context):
        """Attend from ``x`` (queries) to ``context`` (keys and values)."""
        queries = self.to_q(x)
        keys = self.to_k(context)
        values = self.to_v(context)
        attended = optimized_attention(queries, keys, values, self.heads)
        return self.to_out(attended)
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroCrossAttn(nn.Module):
    """Residual cross-attention from a feature map to a context feature map.

    Both inputs are group-normalised, flattened over their spatial axes and
    attended with 64-channel heads (``query_dim // 64`` heads).
    """

    def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None):
        super().__init__()
        dim_head = 64
        heads = query_dim // dim_head
        factory = {"dtype": dtype, "device": device}

        self.attn = _CrossAttnInner(query_dim, context_dim, heads, dim_head, operations=operations, **factory)
        self.norm1 = operations.GroupNorm(32, query_dim, **factory)
        self.norm2 = operations.GroupNorm(32, context_dim, **factory)

    def forward(self, context, x, control_scale=1):
        """Return ``x`` plus the attention output scaled by ``control_scale``.

        Args:
            context: 4-D feature map supplying keys and values.
            x: 4-D feature map supplying queries; also the residual input.
            control_scale: Multiplier applied to the attention branch.
        """
        _, _, height, width = x.shape

        # [B, C, H, W] -> [B, H*W, C] token layout for the attention module.
        query_tokens = self.norm1(x).flatten(2).transpose(1, 2)
        context_tokens = self.norm2(context).flatten(2).transpose(1, 2)

        attended = self.attn(query_tokens, context_tokens)
        # Back to [B, C, H, W] so it can be added to the input.
        attended = attended.transpose(1, 2).unflatten(2, (height, width))

        return x + attended * control_scale
|
||||||
|
|
||||||
|
|
||||||
|
class GLVControl(nn.Module):
|
||||||
|
"""SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only)."""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=4,
|
||||||
|
model_channels=320,
|
||||||
|
num_res_blocks=2,
|
||||||
|
attention_resolutions=(4, 2),
|
||||||
|
channel_mult=(1, 2, 4),
|
||||||
|
num_head_channels=64,
|
||||||
|
transformer_depth=(1, 2, 10),
|
||||||
|
context_dim=2048,
|
||||||
|
adm_in_channels=2816,
|
||||||
|
use_linear_in_transformer=True,
|
||||||
|
use_checkpoint=False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model_channels = model_channels
|
||||||
|
time_embed_dim = model_channels * 4
|
||||||
|
|
||||||
|
self.time_embed = nn.Sequential(
|
||||||
|
operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.label_emb = nn.Sequential(
|
||||||
|
nn.Sequential(
|
||||||
|
operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_blocks = nn.ModuleList([
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
])
|
||||||
|
ch = model_channels
|
||||||
|
ds = 1
|
||||||
|
for level, mult in enumerate(channel_mult):
|
||||||
|
for nr in range(num_res_blocks):
|
||||||
|
layers = [
|
||||||
|
ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels,
|
||||||
|
dtype=dtype, device=device, operations=operations)
|
||||||
|
]
|
||||||
|
ch = mult * model_channels
|
||||||
|
if ds in attention_resolutions:
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
layers.append(
|
||||||
|
SpatialTransformer(ch, num_heads, num_head_channels,
|
||||||
|
depth=transformer_depth[level], context_dim=context_dim,
|
||||||
|
use_linear=use_linear_in_transformer,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
dtype=dtype, device=device, operations=operations)
|
||||||
|
)
|
||||||
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
|
if level != len(channel_mult) - 1:
|
||||||
|
self.input_blocks.append(
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ds *= 2
|
||||||
|
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
self.middle_block = TimestepEmbedSequential(
|
||||||
|
ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
|
||||||
|
SpatialTransformer(ch, num_heads, num_head_channels,
|
||||||
|
depth=transformer_depth[-1], context_dim=context_dim,
|
||||||
|
use_linear=use_linear_in_transformer,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
dtype=dtype, device=device, operations=operations),
|
||||||
|
ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_hint_block = TimestepEmbedSequential(
|
||||||
|
operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
    """Run the control encoder and collect the feature map after every stage.

    ``x`` is passed through ``input_hint_block`` to build the hint, while
    ``xt`` is the tensor pushed through ``input_blocks`` and ``middle_block``.
    The hint is added in place to the output of the first input block only.

    Returns:
        A list holding the output of each input block, in order, followed by
        the middle-block output.
    """
    time_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
    emb = self.time_embed(time_emb) + self.label_emb(y)

    pending_hint = self.input_hint_block(x, emb, context)

    features = []
    h = xt
    for block in self.input_blocks:
        h = block(h, emb, context)
        if pending_hint is not None:
            # Inject the hint once, right after the first block has run.
            h += pending_hint
            pending_hint = None
        features.append(h)

    h = self.middle_block(h, emb, context)
    features.append(h)
    return features
|
||||||
|
|
||||||
|
|
||||||
|
class SUPIR(nn.Module):
    """
    Container for the SUPIR control encoder and its adapter stack.

    Attribute names follow the original SUPIR checkpoint layout:
        control_model.*   -> GLVControl
        project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn
    """
    def __init__(self, device=None, dtype=None, operations=None):
        super().__init__()
        # Keyword arguments shared by every sub-module built here.
        factory = dict(dtype=dtype, device=device, operations=operations)

        self.control_model = GLVControl(**factory)

        project_channel_scale = 2
        cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
        project_channels = [int(base * project_channel_scale) for base in [160] * 4 + [320] * 3 + [640] * 3]
        concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
        cross_attn_insert_idx = [6, 3]

        # One ZeroSFT per (project, cond_output, concat) channel triple.
        self.project_modules = nn.ModuleList([
            ZeroSFT(proj, cond, concat_channels=concat, **factory)
            for proj, cond, concat in zip(project_channels, cond_output_channels, concat_channels)
        ])

        # Inserts are applied in the listed order (6, then 3); the second
        # insert shifts the first ZeroCrossAttn to position 7.
        for idx in cross_attn_insert_idx:
            cross_attn = ZeroCrossAttn(cond_output_channels[idx], concat_channels[idx], **factory)
            self.project_modules.insert(idx, cross_attn)
|
||||||
103
comfy/ldm/supir/supir_patch.py
Normal file
103
comfy/ldm/supir/supir_patch.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import torch
|
||||||
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample
|
||||||
|
|
||||||
|
|
||||||
|
class SUPIRPatch:
    """
    Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters).

    GLVControl is run lazily, once per step, from the middle-block hook. Its
    cached features and the adapters are then consumed back to front:
    ``adapter_idx`` and ``control_idx`` start at the last entry and are
    decremented by every hook that uses them.

    Hooks (see ``register``):
        middle_after  -> "middle_block_after_patch"
        output_block  -> output block patch
        pre_upsample  -> "forward_timestep_embed_patch" for Upsample layers
    """
    # Sigma at (and above) which the interpolated scale equals strength_start.
    SIGMA_MAX = 14.6146

    def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end):
        self.model_patch = model_patch  # CoreModelPatcher wrapping GLVControl
        self.project_modules = project_modules  # nn.ModuleList of ZeroSFT/ZeroCrossAttn
        self.hint_latent = hint_latent  # encoded LQ image latent
        self.strength_start = strength_start
        self.strength_end = strength_end
        self.cached_features = None
        self.adapter_idx = 0
        self.control_idx = 0
        self.current_control_idx = 0
        # Defined up front so every hook can read it; middle_after overwrites
        # it at the start of each step. strength_end is also the fallback used
        # by _get_control_scale when no sigma is available.
        self.current_scale = strength_end
        self.active = True

    def _ensure_features(self, kwargs):
        """Run GLVControl on first call per step, cache results."""
        if self.cached_features is not None:
            return
        x = kwargs["x"]
        b = x.shape[0]
        hint = self.hint_latent.to(device=x.device, dtype=x.dtype)
        if hint.shape[0] != b:
            # Broadcast a single hint, otherwise tile and truncate to the batch size.
            hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b]
        self.cached_features = self.model_patch.model.control_model(
            hint, kwargs["timesteps"], x,
            kwargs["context"], kwargs["y"]
        )
        # Start consuming adapters and features from the end.
        self.adapter_idx = len(self.project_modules) - 1
        self.control_idx = len(self.cached_features) - 1

    def _get_control_scale(self, kwargs):
        """Return the control scale for the current step.

        Interpolates linearly between ``strength_end`` (sigma 0) and
        ``strength_start`` (sigma >= SIGMA_MAX) using the first entry of
        ``transformer_options["sigmas"]``. Falls back to ``strength_end``
        when both strengths are equal or no sigma is provided.
        """
        if self.strength_start == self.strength_end:
            return self.strength_end
        sigma = kwargs["transformer_options"].get("sigmas")
        if sigma is None:
            return self.strength_end
        s = sigma[0].item() if sigma.dim() > 0 else sigma.item()
        t = min(s / self.SIGMA_MAX, 1.0)
        return t * (self.strength_start - self.strength_end) + self.strength_end

    def middle_after(self, kwargs):
        """middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block."""
        self.cached_features = None  # reset from previous step
        self.current_scale = self._get_control_scale(kwargs)
        self.active = self.current_scale > 0
        if not self.active:
            return {"h": kwargs["h"]}
        self._ensure_features(kwargs)
        h = kwargs["h"]
        h = self.project_modules[self.adapter_idx](
            self.cached_features[self.control_idx], h, control_scale=self.current_scale
        )
        self.adapter_idx -= 1
        self.control_idx -= 1
        return {"h": h}

    def output_block(self, h, hsp, transformer_options):
        """output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat.

        When the patch is inactive, or no control features are cached (the
        middle-block hook has not run for this step), ``(h, hsp)`` is returned
        unchanged.
        """
        # Same guard as pre_upsample: without cached features there is nothing to fuse.
        if not self.active or self.cached_features is None:
            return h, hsp
        # Remember which feature was used so pre_upsample can reuse it.
        self.current_control_idx = self.control_idx
        h = self.project_modules[self.adapter_idx](
            self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale
        )
        self.adapter_idx -= 1
        self.control_idx -= 1
        return h, None

    def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw):
        """forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample."""
        block_type, _ = transformer_options["block"]
        if block_type == "output" and self.active and self.cached_features is not None:
            x = self.project_modules[self.adapter_idx](
                self.cached_features[self.current_control_idx], x, control_scale=self.current_scale
            )
            self.adapter_idx -= 1
        return layer(x, output_shape=output_shape)

    def to(self, device_or_dtype):
        """Move the hint latent on device changes and drop cached features; dtypes are ignored."""
        if isinstance(device_or_dtype, torch.device):
            self.cached_features = None
            if self.hint_latent is not None:
                self.hint_latent = self.hint_latent.to(device_or_dtype)
        return self

    def models(self):
        """Return the wrapped model patcher(s) as a list."""
        return [self.model_patch]

    def register(self, model_patcher):
        """Register all patches on a cloned model patcher."""
        model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch")
        model_patcher.set_model_output_block_patch(self.output_block)
        model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch")
|
||||||
@ -53,6 +53,7 @@ import comfy.ldm.kandinsky5.model
|
|||||||
import comfy.ldm.anima.model
|
import comfy.ldm.anima.model
|
||||||
import comfy.ldm.ace.ace_step15
|
import comfy.ldm.ace.ace_step15
|
||||||
import comfy.ldm.rt_detr.rtdetr_v4
|
import comfy.ldm.rt_detr.rtdetr_v4
|
||||||
|
import comfy.ldm.ernie.model
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@ -1962,3 +1963,14 @@ class Kandinsky5Image(Kandinsky5):
|
|||||||
class RT_DETR_v4(BaseModel):
|
class RT_DETR_v4(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||||
|
|
||||||
|
class ErnieImage(BaseModel):
    """BaseModel wrapper that uses comfy.ldm.ernie.model.ErnieImageModel as its diffusion model."""
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.ernie.model.ErnieImageModel,
        )

    def extra_conds(self, **kwargs):
        """Extend the base conds with ``c_crossattn`` when ``cross_attn`` is provided."""
        conds = super().extra_conds(**kwargs)
        text_cond = kwargs.get("cross_attn", None)
        if text_cond is None:
            return conds
        conds['c_crossattn'] = comfy.conds.CONDRegular(text_cond)
        return conds
|
||||||
|
|||||||
@ -713,6 +713,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "ernie"
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -1762,6 +1762,21 @@ def supports_mxfp8_compute(device=None):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def supports_fp64(device=None):
    """Return False when any of the MPS / Intel XPU / DirectML / ixuca checks is true, else True.

    Only the MPS check looks at ``device``; the other checks take no arguments.
    """
    fp64_unavailable = (
        is_device_mps(device)
        or is_intel_xpu()
        or is_directml_enabled()
        or is_ixuca()
    )
    return not fp64_unavailable
|
||||||
|
|
||||||
def extended_fp16_support():
|
def extended_fp16_support():
|
||||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||||
if torch_version_numeric < (2, 7):
|
if torch_version_numeric < (2, 7):
|
||||||
|
|||||||
@ -595,6 +595,10 @@ class ModelPatcher:
|
|||||||
def set_model_noise_refiner_patch(self, patch):
|
def set_model_noise_refiner_patch(self, patch):
|
||||||
self.set_model_patch(patch, "noise_refiner")
|
self.set_model_patch(patch, "noise_refiner")
|
||||||
|
|
||||||
|
def set_model_middle_block_after_patch(self, patch):
    """Register ``patch`` via set_model_patch under the "middle_block_after_patch" name."""
    self.set_model_patch(patch, "middle_block_after_patch")
|
||||||
|
|
||||||
|
|
||||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||||
rope_options["scale_x"] = scale_x
|
rope_options["scale_x"] = scale_x
|
||||||
|
|||||||
@ -1151,7 +1151,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
if param is None:
|
if param is None:
|
||||||
continue
|
continue
|
||||||
p = fn(param)
|
p = fn(param)
|
||||||
if p.is_inference():
|
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
||||||
p = p.clone()
|
p = p.clone()
|
||||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||||
for key, buf in self._buffers.items():
|
for key, buf in self._buffers.items():
|
||||||
|
|||||||
@ -62,6 +62,7 @@ import comfy.text_encoders.anima
|
|||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
import comfy.text_encoders.longcat_image
|
import comfy.text_encoders.longcat_image
|
||||||
import comfy.text_encoders.qwen35
|
import comfy.text_encoders.qwen35
|
||||||
|
import comfy.text_encoders.ernie
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -1235,6 +1236,7 @@ class TEModel(Enum):
|
|||||||
QWEN35_4B = 25
|
QWEN35_4B = 25
|
||||||
QWEN35_9B = 26
|
QWEN35_9B = 26
|
||||||
QWEN35_27B = 27
|
QWEN35_27B = 27
|
||||||
|
MINISTRAL_3_3B = 28
|
||||||
|
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
@ -1301,6 +1303,8 @@ def detect_te_model(sd):
|
|||||||
return TEModel.MISTRAL3_24B
|
return TEModel.MISTRAL3_24B
|
||||||
else:
|
else:
|
||||||
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
|
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
|
||||||
|
if weight.shape[0] == 3072:
|
||||||
|
return TEModel.MINISTRAL_3_3B
|
||||||
|
|
||||||
return TEModel.LLAMA3_8
|
return TEModel.LLAMA3_8
|
||||||
return None
|
return None
|
||||||
@ -1458,6 +1462,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif te_model == TEModel.QWEN3_06B:
|
elif te_model == TEModel.QWEN3_06B:
|
||||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||||
|
elif te_model == TEModel.MINISTRAL_3_3B:
|
||||||
|
clip_target.clip = comfy.text_encoders.ernie.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.ernie.ErnieTokenizer
|
||||||
|
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||||
else:
|
else:
|
||||||
# clip_l
|
# clip_l
|
||||||
if clip_type == CLIPType.SD3:
|
if clip_type == CLIPType.SD3:
|
||||||
|
|||||||
@ -26,6 +26,7 @@ import comfy.text_encoders.z_image
|
|||||||
import comfy.text_encoders.anima
|
import comfy.text_encoders.anima
|
||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
import comfy.text_encoders.longcat_image
|
import comfy.text_encoders.longcat_image
|
||||||
|
import comfy.text_encoders.ernie
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
@ -1749,6 +1750,37 @@ class RT_DETR_v4(supported_models_base.BASE):
|
|||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
|
|
||||||
|
class ErnieImage(supported_models_base.BASE):
    """Supported-model entry matched by ``image_model == "ernie"`` in the detected unet config."""
    unet_config = {"image_model": "ernie"}

    sampling_settings = {"multiplier": 1000.0, "shift": 3.0}

    memory_usage_factor = 10.0

    unet_extra_config = {}
    latent_format = latent_formats.Flux2

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Build the model_base.ErnieImage instance for this config."""
        return model_base.ErnieImage(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the tokenizer/text-encoder pair, with encoder options detected from ``state_dict``."""
        te_prefix = "{}ministral3_3b.transformer.".format(self.text_encoder_key_prefix[0])
        detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, te_prefix)
        return supported_models_base.ClipTarget(
            comfy.text_encoders.ernie.ErnieTokenizer,
            comfy.text_encoders.ernie.te(**detected),
        )
|
||||||
|
|
||||||
|
|
||||||
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
38
comfy/text_encoders/ernie.py
Normal file
38
comfy/text_encoders/ernie.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from .flux import Mistral3Tokenizer
|
||||||
|
from comfy import sd1_clip
|
||||||
|
import comfy.text_encoders.llama
|
||||||
|
|
||||||
|
class Ministral3_3BTokenizer(Mistral3Tokenizer):
    """Mistral3Tokenizer whose embedding key defaults to ``ministral3_3b``."""
    def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='ministral3_3b', tokenizer_data={}):
        # NOTE(review): the 5120 default mirrors Mistral3Tokenizer; confirm it is
        # the intended embedding width for this encoder.
        # __init__ must not return a value, so the parent initialiser is simply called.
        super().__init__(embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
class ErnieTokenizer(sd1_clip.SD1Tokenizer):
    """SD1Tokenizer wrapper registering Mistral3Tokenizer under the ``ministral3_3b`` name."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            name="ministral3_3b",
            tokenizer=Mistral3Tokenizer,
        )

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
        """Tokenize ``text`` with prompt weighting disabled; ``llama_template`` is accepted but not used."""
        return super().tokenize_with_weights(
            text,
            return_word_ids=return_word_ids,
            disable_weights=True,
            **kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class Ministral3_3BModel(sd1_clip.SDClipModel):
    """SDClipModel configured to run comfy.text_encoders.llama.Ministral3_3B."""
    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},  # empty dict: no config overrides are passed
            dtype=dtype,
            special_tokens={"start": 1, "pad": 0},
            layer_norm_hidden_state=False,
            model_class=comfy.text_encoders.llama.Ministral3_3B,
            # The same flag drives both mask usage and mask return.
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class ErnieTEModel(sd1_clip.SD1ClipModel):
    """SD1ClipModel wrapper that defaults to the ``ministral3_3b`` name and Ministral3_3BModel."""
    def __init__(self, device="cpu", dtype=None, model_options={}, name="ministral3_3b", clip_model=Ministral3_3BModel):
        # All arguments are forwarded unchanged to SD1ClipModel.
        super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Build an ErnieTEModel subclass with the given dtype/quantization baked in.

    ``dtype_llama``, when not None, replaces the ``dtype`` passed at
    construction. ``llama_quantization_metadata``, when not None, is stored
    under ``"quantization_metadata"`` in a copy of ``model_options``.
    """
    class ErnieTEModel_(ErnieTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            effective_dtype = dtype if dtype_llama is None else dtype_llama
            if llama_quantization_metadata is not None:
                # Copy first so the caller's dict (or the shared default) is never mutated.
                model_options = model_options.copy()
                model_options["quantization_metadata"] = llama_quantization_metadata
            super().__init__(device=device, dtype=effective_dtype, model_options=model_options)

    return ErnieTEModel_
|
||||||
@ -116,9 +116,9 @@ class MistralTokenizerClass:
|
|||||||
return LlamaTokenizerFast(**kwargs)
|
return LlamaTokenizerFast(**kwargs)
|
||||||
|
|
||||||
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
|
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_data={}):
|
||||||
self.tekken_data = tokenizer_data.get("tekken_model", None)
|
self.tekken_data = tokenizer_data.get("tekken_model", None)
|
||||||
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {"tekken_model": self.tekken_data}
|
return {"tekken_model": self.tekken_data}
|
||||||
|
|||||||
@ -60,6 +60,30 @@ class Mistral3Small24BConfig:
|
|||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
|
||||||
|
@dataclass
class Ministral3_3BConfig:
    # Configuration values for the Ministral3_3B text encoder.
    #
    # Annotated attributes are dataclass fields and can be overridden through
    # the generated __init__ (e.g. Ministral3_3BConfig(**config_dict)).
    vocab_size: int = 131072
    hidden_size: int = 3072
    intermediate_size: int = 9216
    num_hidden_layers: int = 26
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    max_position_embeddings: int = 262144
    rms_norm_eps: float = 1e-5
    rope_theta: float = 1000000.0
    transformer_type: str = "llama"
    # The attributes below carry no type annotation, so they are plain class
    # attributes rather than dataclass fields: __init__ does not accept them.
    head_dim = 128
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = False
    rope_dims = None
    q_norm = None
    k_norm = None
    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False
    # Class-level list shared by all instances (not a dataclass field).
    stop_tokens = [2]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen25_3BConfig:
|
class Qwen25_3BConfig:
|
||||||
vocab_size: int = 151936
|
vocab_size: int = 151936
|
||||||
@ -946,6 +970,15 @@ class Mistral3Small24B(BaseLlama, torch.nn.Module):
|
|||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
|
class Ministral3_3B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
    """Llama2_ backbone built from a Ministral3_3BConfig."""
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        # config_dict may only contain the annotated fields of Ministral3_3BConfig.
        cfg = Ministral3_3BConfig(**config_dict)
        self.num_layers = cfg.num_hidden_layers

        backbone = Llama2_(cfg, device=device, dtype=dtype, ops=operations)
        self.model = backbone
        self.dtype = dtype
|
||||||
|
|
||||||
class Qwen25_3B(BaseLlama, torch.nn.Module):
|
class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -52,6 +52,26 @@ class TaskImageContent(BaseModel):
|
|||||||
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskVideoContentUrl(BaseModel):
    # Wrapper object holding the URL of a video input.
    url: str = Field(...)  # required
|
||||||
|
|
||||||
|
|
||||||
|
class TaskVideoContent(BaseModel):
    # Video entry for a task's `content` list.
    type: str = Field("video_url")  # defaults to "video_url"
    video_url: TaskVideoContentUrl = Field(...)  # required
    role: str = Field("reference_video")  # defaults to "reference_video"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskAudioContentUrl(BaseModel):
    # Wrapper object holding the URL of an audio input.
    url: str = Field(...)  # required
|
||||||
|
|
||||||
|
|
||||||
|
class TaskAudioContent(BaseModel):
    # Audio entry for a task's `content` list.
    type: str = Field("audio_url")  # defaults to "audio_url"
    audio_url: TaskAudioContentUrl = Field(...)  # required
    role: str = Field("reference_audio")  # defaults to "reference_audio"
|
||||||
|
|
||||||
|
|
||||||
class Text2VideoTaskCreationRequest(BaseModel):
|
class Text2VideoTaskCreationRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||||
@ -64,6 +84,17 @@ class Image2VideoTaskCreationRequest(BaseModel):
|
|||||||
generate_audio: bool | None = Field(...)
|
generate_audio: bool | None = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class Seedance2TaskCreationRequest(BaseModel):
    # Request body for creating a Seedance 2.0 task.
    model: str = Field(...)  # required
    # At least one entry; text, image, video and audio entries may be mixed.
    content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = Field(..., min_length=1)
    # Every field below is optional and defaults to None.
    generate_audio: bool | None = Field(None)
    resolution: str | None = Field(None)
    ratio: str | None = Field(None)
    duration: int | None = Field(None, ge=4, le=15)  # 4..15 inclusive; presumably seconds — confirm
    seed: int | None = Field(None, ge=0, le=2147483647)  # 0..2**31 - 1 inclusive
    watermark: bool | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class TaskCreationResponse(BaseModel):
|
class TaskCreationResponse(BaseModel):
|
||||||
id: str = Field(...)
|
id: str = Field(...)
|
||||||
|
|
||||||
@ -77,12 +108,27 @@ class TaskStatusResult(BaseModel):
|
|||||||
video_url: str = Field(...)
|
video_url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatusUsage(BaseModel):
    # Token usage attached to a task status response; both counts default to 0.
    completion_tokens: int = Field(0)
    total_tokens: int = Field(0)
|
||||||
|
|
||||||
|
|
||||||
class TaskStatusResponse(BaseModel):
|
class TaskStatusResponse(BaseModel):
|
||||||
id: str = Field(...)
|
id: str = Field(...)
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||||
error: TaskStatusError | None = Field(None)
|
error: TaskStatusError | None = Field(None)
|
||||||
content: TaskStatusResult | None = Field(None)
|
content: TaskStatusResult | None = Field(None)
|
||||||
|
usage: TaskStatusUsage | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||||
|
SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||||
|
("dreamina-seedance-2-0-260128", False): 0.007,
|
||||||
|
("dreamina-seedance-2-0-260128", True): 0.0043,
|
||||||
|
("dreamina-seedance-2-0-fast-260128", False): 0.0056,
|
||||||
|
("dreamina-seedance-2-0-fast-260128", True): 0.0033,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
RECOMMENDED_PRESETS = [
|
RECOMMENDED_PRESETS = [
|
||||||
@ -112,6 +158,12 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
|||||||
("Custom", None, None),
|
("Custom", None, None),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Seedance 2.0 reference video pixel count limits per model.
|
||||||
|
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
|
||||||
|
"dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408},
|
||||||
|
"dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408},
|
||||||
|
}
|
||||||
|
|
||||||
# The time in this dictionary are given for 10 seconds duration.
|
# The time in this dictionary are given for 10 seconds duration.
|
||||||
VIDEO_TASKS_EXECUTION_TIME = {
|
VIDEO_TASKS_EXECUTION_TIME = {
|
||||||
"seedance-1-0-lite-t2v-250428": {
|
"seedance-1-0-lite-t2v-250428": {
|
||||||
|
|||||||
@ -8,16 +8,23 @@ from comfy_api.latest import IO, ComfyExtension, Input
|
|||||||
from comfy_api_nodes.apis.bytedance import (
|
from comfy_api_nodes.apis.bytedance import (
|
||||||
RECOMMENDED_PRESETS,
|
RECOMMENDED_PRESETS,
|
||||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||||
|
SEEDANCE2_PRICE_PER_1K_TOKENS,
|
||||||
|
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
|
||||||
VIDEO_TASKS_EXECUTION_TIME,
|
VIDEO_TASKS_EXECUTION_TIME,
|
||||||
Image2VideoTaskCreationRequest,
|
Image2VideoTaskCreationRequest,
|
||||||
ImageTaskCreationResponse,
|
ImageTaskCreationResponse,
|
||||||
|
Seedance2TaskCreationRequest,
|
||||||
Seedream4Options,
|
Seedream4Options,
|
||||||
Seedream4TaskCreationRequest,
|
Seedream4TaskCreationRequest,
|
||||||
|
TaskAudioContent,
|
||||||
|
TaskAudioContentUrl,
|
||||||
TaskCreationResponse,
|
TaskCreationResponse,
|
||||||
TaskImageContent,
|
TaskImageContent,
|
||||||
TaskImageContentUrl,
|
TaskImageContentUrl,
|
||||||
TaskStatusResponse,
|
TaskStatusResponse,
|
||||||
TaskTextContent,
|
TaskTextContent,
|
||||||
|
TaskVideoContent,
|
||||||
|
TaskVideoContentUrl,
|
||||||
Text2ImageTaskCreationRequest,
|
Text2ImageTaskCreationRequest,
|
||||||
Text2VideoTaskCreationRequest,
|
Text2VideoTaskCreationRequest,
|
||||||
)
|
)
|
||||||
@ -29,7 +36,10 @@ from comfy_api_nodes.util import (
|
|||||||
image_tensor_pair_to_batch,
|
image_tensor_pair_to_batch,
|
||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
|
upload_audio_to_comfyapi,
|
||||||
|
upload_image_to_comfyapi,
|
||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
|
upload_video_to_comfyapi,
|
||||||
validate_image_aspect_ratio,
|
validate_image_aspect_ratio,
|
||||||
validate_image_dimensions,
|
validate_image_dimensions,
|
||||||
validate_string,
|
validate_string,
|
||||||
@ -46,12 +56,56 @@ SEEDREAM_MODELS = {
|
|||||||
# Long-running tasks endpoints(e.g., video)
|
# Long-running tasks endpoints(e.g., video)
|
||||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||||
|
BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/contents/generations/tasks" # + /{task_id}
|
||||||
|
|
||||||
|
SEEDANCE_MODELS = {
|
||||||
|
"Seedance 2.0": "dreamina-seedance-2-0-260128",
|
||||||
|
"Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128",
|
||||||
|
}
|
||||||
|
|
||||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None:
|
||||||
|
"""Validate reference video pixel count against Seedance 2.0 model limits."""
|
||||||
|
limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
|
||||||
|
if not limits:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
w, h = video.get_dimensions()
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
pixels = w * h
|
||||||
|
min_px = limits.get("min")
|
||||||
|
max_px = limits.get("max")
|
||||||
|
if min_px and pixels < min_px:
|
||||||
|
raise ValueError(
|
||||||
|
f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model."
|
||||||
|
)
|
||||||
|
if max_px and pixels > max_px:
|
||||||
|
raise ValueError(
|
||||||
|
f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. "
|
||||||
|
f"Maximum is {max_px:,}px for this model. Try downscaling the video."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||||
|
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||||
|
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||||
|
if rate is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extractor(response: TaskStatusResponse) -> float | None:
|
||||||
|
if response.usage is None:
|
||||||
|
return None
|
||||||
|
return response.usage.total_tokens * 1.43 * rate / 1_000.0
|
||||||
|
|
||||||
|
return extractor
|
||||||
|
|
||||||
|
|
||||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||||
if response.error:
|
if response.error:
|
||||||
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
||||||
@ -335,8 +389,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
mp_provided = out_num_pixels / 1_000_000.0
|
mp_provided = out_num_pixels / 1_000_000.0
|
||||||
if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400:
|
if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Minimum image resolution for the selected model is 3.68MP, "
|
f"Minimum image resolution for the selected model is 3.68MP, " f"but {mp_provided:.2f}MP provided."
|
||||||
f"but {mp_provided:.2f}MP provided."
|
|
||||||
)
|
)
|
||||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -952,33 +1005,6 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def process_video_task(
|
|
||||||
cls: type[IO.ComfyNode],
|
|
||||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
|
||||||
estimated_duration: int | None,
|
|
||||||
) -> IO.NodeOutput:
|
|
||||||
if payload.model in DEPRECATED_MODELS:
|
|
||||||
logger.warning(
|
|
||||||
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
|
||||||
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
|
||||||
payload.model,
|
|
||||||
)
|
|
||||||
initial_response = await sync_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
|
||||||
data=payload,
|
|
||||||
response_model=TaskCreationResponse,
|
|
||||||
)
|
|
||||||
response = await poll_op(
|
|
||||||
cls,
|
|
||||||
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
|
||||||
status_extractor=lambda r: r.status,
|
|
||||||
estimated_duration=estimated_duration,
|
|
||||||
response_model=TaskStatusResponse,
|
|
||||||
)
|
|
||||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
|
||||||
|
|
||||||
|
|
||||||
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||||
for i in text_params:
|
for i in text_params:
|
||||||
if f"--{i} " in prompt:
|
if f"--{i} " in prompt:
|
||||||
@ -1040,6 +1066,542 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _seedance2_text_inputs(resolutions: list[str]):
|
||||||
|
return [
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text prompt for video generation.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=resolutions,
|
||||||
|
tooltip="Resolution of the output video.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"ratio",
|
||||||
|
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
||||||
|
tooltip="Aspect ratio of the output video.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=7,
|
||||||
|
min=4,
|
||||||
|
max=15,
|
||||||
|
step=1,
|
||||||
|
tooltip="Duration of the output video in seconds (4-15).",
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"generate_audio",
|
||||||
|
default=True,
|
||||||
|
tooltip="Enable audio generation for the output video.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="ByteDance2TextToVideoNode",
|
||||||
|
display_name="ByteDance Seedance 2.0 Text to Video",
|
||||||
|
category="api node/video/ByteDance",
|
||||||
|
description="Generate video using Seedance 2.0 models based on a text prompt.",
|
||||||
|
inputs=[
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
||||||
|
],
|
||||||
|
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed controls whether the node should re-run; "
|
||||||
|
"results are non-deterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"watermark",
|
||||||
|
default=False,
|
||||||
|
tooltip="Whether to add a watermark to the video.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$rate480 := 10044;
|
||||||
|
$rate720 := 21600;
|
||||||
|
$rate1080 := 48800;
|
||||||
|
$m := widgets.model;
|
||||||
|
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
$rate := $res = "1080p" ? $rate1080 :
|
||||||
|
$res = "720p" ? $rate720 :
|
||||||
|
$rate480;
|
||||||
|
$cost := $dur * $rate * $pricePer1K / 1000;
|
||||||
|
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
watermark: bool,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||||
|
model_id = SEEDANCE_MODELS[model["model"]]
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||||
|
data=Seedance2TaskCreationRequest(
|
||||||
|
model=model_id,
|
||||||
|
content=[TaskTextContent(text=model["prompt"])],
|
||||||
|
generate_audio=model["generate_audio"],
|
||||||
|
resolution=model["resolution"],
|
||||||
|
ratio=model["ratio"],
|
||||||
|
duration=model["duration"],
|
||||||
|
seed=seed,
|
||||||
|
watermark=watermark,
|
||||||
|
),
|
||||||
|
response_model=TaskCreationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||||
|
poll_interval=9,
|
||||||
|
max_poll_attempts=180,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="ByteDance2FirstLastFrameNode",
|
||||||
|
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
|
||||||
|
category="api node/video/ByteDance",
|
||||||
|
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
|
||||||
|
inputs=[
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
||||||
|
],
|
||||||
|
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"first_frame",
|
||||||
|
tooltip="First frame image for the video.",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"last_frame",
|
||||||
|
tooltip="Last frame image for the video.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed controls whether the node should re-run; "
|
||||||
|
"results are non-deterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"watermark",
|
||||||
|
default=False,
|
||||||
|
tooltip="Whether to add a watermark to the video.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$rate480 := 10044;
|
||||||
|
$rate720 := 21600;
|
||||||
|
$rate1080 := 48800;
|
||||||
|
$m := widgets.model;
|
||||||
|
$pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
$rate := $res = "1080p" ? $rate1080 :
|
||||||
|
$res = "720p" ? $rate720 :
|
||||||
|
$rate480;
|
||||||
|
$cost := $dur * $rate * $pricePer1K / 1000;
|
||||||
|
{"type": "usd", "usd": $cost, "format": {"approximate": true}}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: dict,
|
||||||
|
first_frame: Input.Image,
|
||||||
|
seed: int,
|
||||||
|
watermark: bool,
|
||||||
|
last_frame: Input.Image | None = None,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||||
|
model_id = SEEDANCE_MODELS[model["model"]]
|
||||||
|
|
||||||
|
content: list[TaskTextContent | TaskImageContent] = [
|
||||||
|
TaskTextContent(text=model["prompt"]),
|
||||||
|
TaskImageContent(
|
||||||
|
image_url=TaskImageContentUrl(
|
||||||
|
url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
|
||||||
|
),
|
||||||
|
role="first_frame",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
if last_frame is not None:
|
||||||
|
content.append(
|
||||||
|
TaskImageContent(
|
||||||
|
image_url=TaskImageContentUrl(
|
||||||
|
url=await upload_image_to_comfyapi(cls, last_frame, wait_label="Uploading last frame.")
|
||||||
|
),
|
||||||
|
role="last_frame",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||||
|
data=Seedance2TaskCreationRequest(
|
||||||
|
model=model_id,
|
||||||
|
content=content,
|
||||||
|
generate_audio=model["generate_audio"],
|
||||||
|
resolution=model["resolution"],
|
||||||
|
ratio=model["ratio"],
|
||||||
|
duration=model["duration"],
|
||||||
|
seed=seed,
|
||||||
|
watermark=watermark,
|
||||||
|
),
|
||||||
|
response_model=TaskCreationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||||
|
poll_interval=9,
|
||||||
|
max_poll_attempts=180,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
def _seedance2_reference_inputs(resolutions: list[str]):
|
||||||
|
return [
|
||||||
|
*_seedance2_text_inputs(resolutions),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_images",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Image.Input("reference_image"),
|
||||||
|
names=[
|
||||||
|
"image_1",
|
||||||
|
"image_2",
|
||||||
|
"image_3",
|
||||||
|
"image_4",
|
||||||
|
"image_5",
|
||||||
|
"image_6",
|
||||||
|
"image_7",
|
||||||
|
"image_8",
|
||||||
|
"image_9",
|
||||||
|
],
|
||||||
|
min=0,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_videos",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Video.Input("reference_video"),
|
||||||
|
names=["video_1", "video_2", "video_3"],
|
||||||
|
min=0,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_audios",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Audio.Input("reference_audio"),
|
||||||
|
names=["audio_1", "audio_2", "audio_3"],
|
||||||
|
min=0,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="ByteDance2ReferenceNode",
|
||||||
|
display_name="ByteDance Seedance 2.0 Reference to Video",
|
||||||
|
category="api node/video/ByteDance",
|
||||||
|
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
|
||||||
|
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
|
||||||
|
inputs=[
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
|
||||||
|
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
|
||||||
|
],
|
||||||
|
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed controls whether the node should re-run; "
|
||||||
|
"results are non-deterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"watermark",
|
||||||
|
default=False,
|
||||||
|
tooltip="Whether to add a watermark to the video.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(
|
||||||
|
widgets=["model", "model.resolution", "model.duration"],
|
||||||
|
input_groups=["model.reference_videos"],
|
||||||
|
),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$rate480 := 10044;
|
||||||
|
$rate720 := 21600;
|
||||||
|
$rate1080 := 48800;
|
||||||
|
$m := widgets.model;
|
||||||
|
$hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
|
||||||
|
$noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
|
||||||
|
$videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
$rate := $res = "1080p" ? $rate1080 :
|
||||||
|
$res = "720p" ? $rate720 :
|
||||||
|
$rate480;
|
||||||
|
$noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
|
||||||
|
$minVideoFactor := $ceil($dur * 5 / 3);
|
||||||
|
$minVideoCost := $minVideoFactor * $rate * $videoPricePer1K / 1000;
|
||||||
|
$maxVideoCost := (15 + $dur) * $rate * $videoPricePer1K / 1000;
|
||||||
|
$hasVideo
|
||||||
|
? {
|
||||||
|
"type": "range_usd",
|
||||||
|
"min_usd": $minVideoCost,
|
||||||
|
"max_usd": $maxVideoCost,
|
||||||
|
"format": {"approximate": true}
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
"type": "usd",
|
||||||
|
"usd": $noVideoCost,
|
||||||
|
"format": {"approximate": true}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
watermark: bool,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(model["prompt"], strip_whitespace=True, min_length=1)
|
||||||
|
|
||||||
|
reference_images = model.get("reference_images", {})
|
||||||
|
reference_videos = model.get("reference_videos", {})
|
||||||
|
reference_audios = model.get("reference_audios", {})
|
||||||
|
|
||||||
|
if not reference_images and not reference_videos:
|
||||||
|
raise ValueError("At least one reference image or video is required.")
|
||||||
|
|
||||||
|
model_id = SEEDANCE_MODELS[model["model"]]
|
||||||
|
has_video_input = len(reference_videos) > 0
|
||||||
|
total_video_duration = 0.0
|
||||||
|
for i, key in enumerate(reference_videos, 1):
|
||||||
|
video = reference_videos[key]
|
||||||
|
_validate_ref_video_pixels(video, model_id, i)
|
||||||
|
try:
|
||||||
|
dur = video.get_duration()
|
||||||
|
if dur < 1.8:
|
||||||
|
raise ValueError(f"Reference video {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.")
|
||||||
|
total_video_duration += dur
|
||||||
|
except ValueError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if total_video_duration > 15.1:
|
||||||
|
raise ValueError(f"Total reference video duration is {total_video_duration:.1f}s. Maximum is 15.1 seconds.")
|
||||||
|
|
||||||
|
total_audio_duration = 0.0
|
||||||
|
for i, key in enumerate(reference_audios, 1):
|
||||||
|
audio = reference_audios[key]
|
||||||
|
dur = int(audio["waveform"].shape[-1]) / int(audio["sample_rate"])
|
||||||
|
if dur < 1.8:
|
||||||
|
raise ValueError(f"Reference audio {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.")
|
||||||
|
total_audio_duration += dur
|
||||||
|
if total_audio_duration > 15.1:
|
||||||
|
raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.")
|
||||||
|
|
||||||
|
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [
|
||||||
|
TaskTextContent(text=model["prompt"]),
|
||||||
|
]
|
||||||
|
for i, key in enumerate(reference_images, 1):
|
||||||
|
content.append(
|
||||||
|
TaskImageContent(
|
||||||
|
image_url=TaskImageContentUrl(
|
||||||
|
url=await upload_image_to_comfyapi(
|
||||||
|
cls,
|
||||||
|
image=reference_images[key],
|
||||||
|
wait_label=f"Uploading image {i}",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
role="reference_image",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for i, key in enumerate(reference_videos, 1):
|
||||||
|
content.append(
|
||||||
|
TaskVideoContent(
|
||||||
|
video_url=TaskVideoContentUrl(
|
||||||
|
url=await upload_video_to_comfyapi(
|
||||||
|
cls,
|
||||||
|
reference_videos[key],
|
||||||
|
wait_label=f"Uploading video {i}",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for key in reference_audios:
|
||||||
|
content.append(
|
||||||
|
TaskAudioContent(
|
||||||
|
audio_url=TaskAudioContentUrl(
|
||||||
|
url=await upload_audio_to_comfyapi(
|
||||||
|
cls,
|
||||||
|
reference_audios[key],
|
||||||
|
container_format="mp3",
|
||||||
|
codec_name="libmp3lame",
|
||||||
|
mime_type="audio/mpeg",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||||
|
data=Seedance2TaskCreationRequest(
|
||||||
|
model=model_id,
|
||||||
|
content=content,
|
||||||
|
generate_audio=model["generate_audio"],
|
||||||
|
resolution=model["resolution"],
|
||||||
|
ratio=model["ratio"],
|
||||||
|
duration=model["duration"],
|
||||||
|
seed=seed,
|
||||||
|
watermark=watermark,
|
||||||
|
),
|
||||||
|
response_model=TaskCreationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
|
||||||
|
poll_interval=9,
|
||||||
|
max_poll_attempts=180,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
async def process_video_task(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||||
|
estimated_duration: int | None,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
if payload.model in DEPRECATED_MODELS:
|
||||||
|
logger.warning(
|
||||||
|
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||||
|
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||||
|
payload.model,
|
||||||
|
)
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||||
|
data=payload,
|
||||||
|
response_model=TaskCreationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
estimated_duration=estimated_duration,
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
class ByteDanceExtension(ComfyExtension):
|
class ByteDanceExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -1050,6 +1612,9 @@ class ByteDanceExtension(ComfyExtension):
|
|||||||
ByteDanceImageToVideoNode,
|
ByteDanceImageToVideoNode,
|
||||||
ByteDanceFirstLastFrameNode,
|
ByteDanceFirstLastFrameNode,
|
||||||
ByteDanceImageReferenceNode,
|
ByteDanceImageReferenceNode,
|
||||||
|
ByteDance2TextToVideoNode,
|
||||||
|
ByteDance2FirstLastFrameNode,
|
||||||
|
ByteDance2ReferenceNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -558,7 +558,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
|
|||||||
(
|
(
|
||||||
$res := $lookup(widgets, "model.resolution");
|
$res := $lookup(widgets, "model.resolution");
|
||||||
$dur := $lookup(widgets, "model.duration");
|
$dur := $lookup(widgets, "model.duration");
|
||||||
$refs := inputGroups["model.reference_images"];
|
$refs := $lookup(inputGroups, "model.reference_images");
|
||||||
$rate := $res = "720p" ? 0.07 : 0.05;
|
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||||
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||||
{"type":"usd","usd": $price}
|
{"type":"usd","usd": $price}
|
||||||
|
|||||||
@ -221,14 +221,17 @@ class TencentTextToModelNode(IO.ComfyNode):
|
|||||||
response_model=To3DProTaskResultResponse,
|
response_model=To3DProTaskResultResponse,
|
||||||
status_extractor=lambda r: r.Status,
|
status_extractor=lambda r: r.Status,
|
||||||
)
|
)
|
||||||
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
|
obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)
|
||||||
|
obj_result = None
|
||||||
|
if obj_file_response:
|
||||||
|
obj_result = await download_and_extract_obj_zip(obj_file_response.Url)
|
||||||
return IO.NodeOutput(
|
return IO.NodeOutput(
|
||||||
f"{task_id}.glb",
|
f"{task_id}.glb",
|
||||||
await download_url_to_file_3d(
|
await download_url_to_file_3d(
|
||||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||||
),
|
),
|
||||||
obj_result.obj,
|
obj_result.obj if obj_result else None,
|
||||||
obj_result.texture,
|
obj_result.texture if obj_result else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -378,17 +381,30 @@ class TencentImageToModelNode(IO.ComfyNode):
|
|||||||
response_model=To3DProTaskResultResponse,
|
response_model=To3DProTaskResultResponse,
|
||||||
status_extractor=lambda r: r.Status,
|
status_extractor=lambda r: r.Status,
|
||||||
)
|
)
|
||||||
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
|
obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)
|
||||||
|
if obj_file_response:
|
||||||
|
obj_result = await download_and_extract_obj_zip(obj_file_response.Url)
|
||||||
|
return IO.NodeOutput(
|
||||||
|
f"{task_id}.glb",
|
||||||
|
await download_url_to_file_3d(
|
||||||
|
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||||
|
),
|
||||||
|
obj_result.obj,
|
||||||
|
obj_result.texture,
|
||||||
|
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
|
||||||
|
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
|
||||||
|
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
|
||||||
|
)
|
||||||
return IO.NodeOutput(
|
return IO.NodeOutput(
|
||||||
f"{task_id}.glb",
|
f"{task_id}.glb",
|
||||||
await download_url_to_file_3d(
|
await download_url_to_file_3d(
|
||||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||||
),
|
),
|
||||||
obj_result.obj,
|
None,
|
||||||
obj_result.texture,
|
None,
|
||||||
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
|
None,
|
||||||
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
|
None,
|
||||||
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -17,6 +17,44 @@ from comfy_api_nodes.util import (
|
|||||||
)
|
)
|
||||||
from comfy_extras.nodes_images import SVG
|
from comfy_extras.nodes_images import SVG
|
||||||
|
|
||||||
|
_ARROW_MODELS = ["arrow-1.1", "arrow-1.1-max", "arrow-preview"]
|
||||||
|
|
||||||
|
|
||||||
|
def _arrow_sampling_inputs():
|
||||||
|
"""Shared sampling inputs for all Arrow model variants."""
|
||||||
|
return [
|
||||||
|
IO.Float.Input(
|
||||||
|
"temperature",
|
||||||
|
default=1.0,
|
||||||
|
min=0.0,
|
||||||
|
max=2.0,
|
||||||
|
step=0.1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Randomness control. Higher values increase randomness.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"top_p",
|
||||||
|
default=1.0,
|
||||||
|
min=0.05,
|
||||||
|
max=1.0,
|
||||||
|
step=0.05,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Nucleus sampling parameter.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"presence_penalty",
|
||||||
|
default=0.0,
|
||||||
|
min=-2.0,
|
||||||
|
max=2.0,
|
||||||
|
step=0.1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Token presence penalty.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class QuiverTextToSVGNode(IO.ComfyNode):
|
class QuiverTextToSVGNode(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -39,6 +77,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
|
|||||||
default="",
|
default="",
|
||||||
tooltip="Additional style or formatting guidance.",
|
tooltip="Additional style or formatting guidance.",
|
||||||
optional=True,
|
optional=True,
|
||||||
|
advanced=True,
|
||||||
),
|
),
|
||||||
IO.Autogrow.Input(
|
IO.Autogrow.Input(
|
||||||
"reference_images",
|
"reference_images",
|
||||||
@ -53,43 +92,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.DynamicCombo.Input(
|
IO.DynamicCombo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[
|
options=[IO.DynamicCombo.Option(m, _arrow_sampling_inputs()) for m in _ARROW_MODELS],
|
||||||
IO.DynamicCombo.Option(
|
|
||||||
"arrow-preview",
|
|
||||||
[
|
|
||||||
IO.Float.Input(
|
|
||||||
"temperature",
|
|
||||||
default=1.0,
|
|
||||||
min=0.0,
|
|
||||||
max=2.0,
|
|
||||||
step=0.1,
|
|
||||||
display_mode=IO.NumberDisplay.slider,
|
|
||||||
tooltip="Randomness control. Higher values increase randomness.",
|
|
||||||
advanced=True,
|
|
||||||
),
|
|
||||||
IO.Float.Input(
|
|
||||||
"top_p",
|
|
||||||
default=1.0,
|
|
||||||
min=0.05,
|
|
||||||
max=1.0,
|
|
||||||
step=0.05,
|
|
||||||
display_mode=IO.NumberDisplay.slider,
|
|
||||||
tooltip="Nucleus sampling parameter.",
|
|
||||||
advanced=True,
|
|
||||||
),
|
|
||||||
IO.Float.Input(
|
|
||||||
"presence_penalty",
|
|
||||||
default=0.0,
|
|
||||||
min=-2.0,
|
|
||||||
max=2.0,
|
|
||||||
step=0.1,
|
|
||||||
display_mode=IO.NumberDisplay.slider,
|
|
||||||
tooltip="Token presence penalty.",
|
|
||||||
advanced=True,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
tooltip="Model to use for SVG generation.",
|
tooltip="Model to use for SVG generation.",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -112,7 +115,16 @@ class QuiverTextToSVGNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.429}""",
|
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$contains(widgets.model, "max")
|
||||||
|
? {"type":"usd","usd":0.3575}
|
||||||
|
: $contains(widgets.model, "preview")
|
||||||
|
? {"type":"usd","usd":0.429}
|
||||||
|
: {"type":"usd","usd":0.286}
|
||||||
|
)
|
||||||
|
""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -176,12 +188,13 @@ class QuiverImageToSVGNode(IO.ComfyNode):
|
|||||||
"auto_crop",
|
"auto_crop",
|
||||||
default=False,
|
default=False,
|
||||||
tooltip="Automatically crop to the dominant subject.",
|
tooltip="Automatically crop to the dominant subject.",
|
||||||
|
advanced=True,
|
||||||
),
|
),
|
||||||
IO.DynamicCombo.Input(
|
IO.DynamicCombo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[
|
options=[
|
||||||
IO.DynamicCombo.Option(
|
IO.DynamicCombo.Option(
|
||||||
"arrow-preview",
|
m,
|
||||||
[
|
[
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"target_size",
|
"target_size",
|
||||||
@ -189,39 +202,12 @@ class QuiverImageToSVGNode(IO.ComfyNode):
|
|||||||
min=128,
|
min=128,
|
||||||
max=4096,
|
max=4096,
|
||||||
tooltip="Square resize target in pixels.",
|
tooltip="Square resize target in pixels.",
|
||||||
),
|
|
||||||
IO.Float.Input(
|
|
||||||
"temperature",
|
|
||||||
default=1.0,
|
|
||||||
min=0.0,
|
|
||||||
max=2.0,
|
|
||||||
step=0.1,
|
|
||||||
display_mode=IO.NumberDisplay.slider,
|
|
||||||
tooltip="Randomness control. Higher values increase randomness.",
|
|
||||||
advanced=True,
|
|
||||||
),
|
|
||||||
IO.Float.Input(
|
|
||||||
"top_p",
|
|
||||||
default=1.0,
|
|
||||||
min=0.05,
|
|
||||||
max=1.0,
|
|
||||||
step=0.05,
|
|
||||||
display_mode=IO.NumberDisplay.slider,
|
|
||||||
tooltip="Nucleus sampling parameter.",
|
|
||||||
advanced=True,
|
|
||||||
),
|
|
||||||
IO.Float.Input(
|
|
||||||
"presence_penalty",
|
|
||||||
default=0.0,
|
|
||||||
min=-2.0,
|
|
||||||
max=2.0,
|
|
||||||
step=0.1,
|
|
||||||
display_mode=IO.NumberDisplay.slider,
|
|
||||||
tooltip="Token presence penalty.",
|
|
||||||
advanced=True,
|
advanced=True,
|
||||||
),
|
),
|
||||||
|
*_arrow_sampling_inputs(),
|
||||||
],
|
],
|
||||||
),
|
)
|
||||||
|
for m in _ARROW_MODELS
|
||||||
],
|
],
|
||||||
tooltip="Model to use for SVG vectorization.",
|
tooltip="Model to use for SVG vectorization.",
|
||||||
),
|
),
|
||||||
@ -245,7 +231,16 @@ class QuiverImageToSVGNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.429}""",
|
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$contains(widgets.model, "max")
|
||||||
|
? {"type":"usd","usd":0.3575}
|
||||||
|
: $contains(widgets.model, "preview")
|
||||||
|
? {"type":"usd","usd":0.429}
|
||||||
|
: {"type":"usd","usd":0.286}
|
||||||
|
)
|
||||||
|
""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
287
comfy_api_nodes/nodes_sonilo.py
Normal file
287
comfy_api_nodes/nodes_sonilo.py
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import IO, ComfyExtension, Input
|
||||||
|
from comfy_api_nodes.util import (
|
||||||
|
ApiEndpoint,
|
||||||
|
audio_bytes_to_audio_input,
|
||||||
|
upload_video_to_comfyapi,
|
||||||
|
validate_string,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.util._helpers import (
|
||||||
|
default_base_url,
|
||||||
|
get_auth_header,
|
||||||
|
get_node_id,
|
||||||
|
is_processing_interrupted,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.util.common_exceptions import ProcessingInterrupted
|
||||||
|
from server import PromptServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SoniloVideoToMusic(IO.ComfyNode):
    """API node that asks Sonilo to compose music for an input video."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: video/prompt/seed inputs and a single audio output."""
        video_input = IO.Video.Input(
            "video",
            tooltip="Input video to generate music from. Maximum duration: 6 minutes.",
        )
        prompt_input = IO.String.Input(
            "prompt",
            default="",
            multiline=True,
            tooltip="Optional text prompt to guide music generation. "
            "Leave empty for best quality - the model will fully analyze the video content.",
        )
        seed_input = IO.Int.Input(
            "seed",
            default=0,
            min=0,
            max=0xFFFFFFFFFFFFFFFF,
            control_after_generate=True,
            tooltip="Seed for reproducibility. Currently ignored by the Sonilo "
            "service but kept for graph consistency.",
        )
        return IO.Schema(
            node_id="SoniloVideoToMusic",
            display_name="Sonilo Video to Music",
            category="api node/audio/Sonilo",
            description="Generate music from video content using Sonilo's AI model. "
            "Analyzes the video and creates matching music.",
            inputs=[video_input, prompt_input, seed_input],
            outputs=[IO.Audio.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr='{"type":"usd","usd":0.009,"format":{"suffix":"/second"}}',
            ),
        )

    @classmethod
    async def execute(
        cls,
        video: Input.Video,
        prompt: str = "",
        seed: int = 0,
    ) -> IO.NodeOutput:
        """Upload the video, request music for it, and return the decoded audio.

        Args:
            video: Source video; uploaded with a 360 second duration limit.
            prompt: Optional guidance text; only sent when non-blank.
            seed: Not sent to the service (see the input tooltip).
        """
        uploaded_url = await upload_video_to_comfyapi(cls, video, max_duration=360)

        payload = aiohttp.FormData()
        payload.add_field("video_url", uploaded_url)
        guidance = prompt.strip()
        if guidance:
            payload.add_field("prompt", guidance)

        endpoint = ApiEndpoint(path="/proxy/sonilo/v2m/generate", method="POST")
        raw_audio = await _stream_sonilo_music(cls, endpoint, payload)
        return IO.NodeOutput(audio_bytes_to_audio_input(raw_audio))
|
||||||
|
|
||||||
|
|
||||||
|
class SoniloTextToMusic(IO.ComfyNode):
    """Generate music from a text prompt using Sonilo's AI model."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: prompt/duration/seed inputs and a single audio output."""
        return IO.Schema(
            node_id="SoniloTextToMusic",
            display_name="Sonilo Text to Music",
            category="api node/audio/Sonilo",
            description="Generate music from a text prompt using Sonilo's AI model. "
            "Leave duration at 0 to let the model infer it from the prompt.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    default="",
                    multiline=True,
                    tooltip="Text prompt describing the music to generate.",
                ),
                IO.Int.Input(
                    "duration",
                    default=0,
                    min=0,
                    max=360,
                    tooltip="Target duration in seconds. Set to 0 to let the model "
                    "infer the duration from the prompt. Maximum: 6 minutes.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    control_after_generate=True,
                    tooltip="Seed for reproducibility. Currently ignored by the Sonilo "
                    "service but kept for graph consistency.",
                ),
            ],
            outputs=[IO.Audio.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
                expr="""
                (
                  widgets.duration > 0
                    ? {"type":"usd","usd": 0.005 * widgets.duration}
                    : {"type":"usd","usd": 0.005, "format":{"suffix":"/second"}}
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        duration: int = 0,
        seed: int = 0,
    ) -> IO.NodeOutput:
        """Request music for ``prompt`` and return the decoded audio.

        Args:
            prompt: Text describing the music; must be non-blank.
            duration: Target length in seconds; only sent when greater than 0.
            seed: Not sent to the service (see the input tooltip).
        """
        validate_string(prompt, strip_whitespace=True, min_length=1)
        form = aiohttp.FormData()
        # Send the same stripped text that was validated above, matching how
        # SoniloVideoToMusic submits its prompt.
        form.add_field("prompt", prompt.strip())
        if duration > 0:
            form.add_field("duration", str(duration))
        audio_bytes = await _stream_sonilo_music(
            cls,
            ApiEndpoint(path="/proxy/sonilo/t2m/generate", method="POST"),
            form,
        )
        return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes))
|
||||||
|
|
||||||
|
|
||||||
|
async def _stream_sonilo_music(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    form: aiohttp.FormData,
) -> bytes:
    """POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes.

    The response is consumed line by line; each non-empty line is expected to
    be a JSON object whose ``type`` selects the handling ("error", "duration",
    "titles"/"title", "audio_chunk", "complete"). Progress text is pushed to
    the UI through ``PromptServer`` while the stream is read.

    Args:
        cls: Node class, used to obtain auth headers and the node id.
        endpoint: Proxy endpoint; its path is joined onto the default base URL
            and its headers are merged over the auth headers.
        form: Multipart form sent as the request body.

    Returns:
        The concatenated chunks of stream 0 if present, otherwise of the
        stream with the smallest index.

    Raises:
        ProcessingInterrupted: If processing was interrupted while streaming.
        Exception: On an HTTP status >= 400, an "error" event, or when the
            stream ends without any audio chunk.
    """
    url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))

    headers: dict[str, str] = {}
    headers.update(get_auth_header(cls))
    headers.update(endpoint.headers)

    node_id = get_node_id(cls)
    start_ts = time.monotonic()
    last_chunk_status_ts = 0.0
    audio_streams: dict[int, list[bytes]] = {}
    title: str | None = None

    timeout = aiohttp.ClientTimeout(total=1200.0, sock_read=300.0)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        PromptServer.instance.send_progress_text("Status: Queued", node_id)
        async with session.post(url, data=form, headers=headers) as resp:
            if resp.status >= 400:
                msg = await _extract_error_message(resp)
                raise Exception(f"Sonilo API error ({resp.status}): {msg}")

            while True:
                # Cancellation is only observed between lines.
                if is_processing_interrupted():
                    raise ProcessingInterrupted("Task cancelled")

                # NOTE(review): readline() is bounded by the session's read
                # buffer size - confirm Sonilo's base64 chunk lines fit in it.
                raw_line = await resp.content.readline()
                if not raw_line:
                    break  # end of stream

                line = raw_line.decode("utf-8").strip()
                if not line:
                    continue

                try:
                    evt = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning("Sonilo: skipping malformed NDJSON line")
                    continue
                if not isinstance(evt, dict):
                    # Valid JSON but not an object (e.g. a bare list or
                    # number): there is no "type" to dispatch on.
                    logger.warning("Sonilo: skipping non-object NDJSON event")
                    continue

                evt_type = evt.get("type")
                if evt_type == "error":
                    code = evt.get("code", "UNKNOWN")
                    message = evt.get("message", "Unknown error")
                    raise Exception(f"Sonilo generation error ({code}): {message}")
                if evt_type == "duration":
                    duration_sec = evt.get("duration_sec")
                    if duration_sec is not None:
                        PromptServer.instance.send_progress_text(
                            f"Status: Generating\nVideo duration: {duration_sec:.1f}s",
                            node_id,
                        )
                elif evt_type in ("titles", "title"):
                    # v2m sends a "titles" list, t2m sends a scalar "title"
                    if evt_type == "titles":
                        titles = evt.get("titles", [])
                        if titles:
                            title = titles[0]
                    else:
                        title = evt.get("title") or title
                    if title:
                        PromptServer.instance.send_progress_text(
                            f"Status: Generating\nTitle: {title}",
                            node_id,
                        )
                elif evt_type == "audio_chunk":
                    stream_idx = evt.get("stream_index", 0)
                    chunk_data = base64.b64decode(evt["data"])
                    audio_streams.setdefault(stream_idx, []).append(chunk_data)

                    # Throttle UI updates to at most one per second.
                    now = time.monotonic()
                    if now - last_chunk_status_ts >= 1.0:
                        total_chunks = sum(len(chunks) for chunks in audio_streams.values())
                        elapsed = int(now - start_ts)
                        status_lines = ["Status: Receiving audio"]
                        if title:
                            status_lines.append(f"Title: {title}")
                        status_lines.append(f"Chunks received: {total_chunks}")
                        status_lines.append(f"Time elapsed: {elapsed}s")
                        PromptServer.instance.send_progress_text("\n".join(status_lines), node_id)
                        last_chunk_status_ts = now
                elif evt_type == "complete":
                    break

    if not audio_streams:
        raise Exception("Sonilo API returned no audio data.")

    PromptServer.instance.send_progress_text("Status: Completed", node_id)
    # Prefer stream 0; otherwise fall back to the lowest stream index received.
    selected_stream = 0 if 0 in audio_streams else min(audio_streams)
    return b"".join(audio_streams[selected_stream])
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_error_message(resp: aiohttp.ClientResponse) -> str:
    """Extract a human-readable error message from an HTTP error response.

    Looks for a ``detail`` entry in a JSON object body: a dict ``detail``
    contributes its ``message`` (or its own string form when it has none),
    any other non-empty ``detail`` is stringified. When the body is not JSON,
    is not an object, or carries no usable ``detail``, the raw response text
    is returned instead.

    Args:
        resp: The error response to inspect.

    Returns:
        The extracted message, always as a ``str``.
    """
    try:
        error_body = await resp.json()
    except Exception:
        # Best-effort: whatever prevented JSON parsing, report the raw body
        # rather than masking the original HTTP error with a parsing error.
        return await resp.text()

    if isinstance(error_body, dict):
        detail = error_body.get("detail")
        if isinstance(detail, dict):
            detail = detail.get("message") or detail
        if detail:
            return str(detail)

    # No usable "detail": the raw body is more informative than "{}".
    return await resp.text()
|
||||||
|
|
||||||
|
|
||||||
|
class SoniloExtension(ComfyExtension):
    """Extension exposing the Sonilo music-generation nodes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[IO.ComfyNode]] = [
            SoniloVideoToMusic,
            SoniloTextToMusic,
        ]
        return node_classes
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> SoniloExtension:
    """Create and return a new ``SoniloExtension`` instance."""
    extension = SoniloExtension()
    return extension
|
||||||
@ -401,7 +401,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.25}""",
|
expr="""{"type":"usd","usd":0.4}""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -510,7 +510,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.25}""",
|
expr="""{"type":"usd","usd":0.6}""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -593,7 +593,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.01}""",
|
expr="""{"type":"usd","usd":0.02}""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,10 @@ import comfy.model_management
|
|||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
import comfy.ldm.lumina.controlnet
|
import comfy.ldm.lumina.controlnet
|
||||||
|
import comfy.ldm.supir.supir_modules
|
||||||
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
||||||
|
from comfy_api.latest import io
|
||||||
|
from comfy.ldm.supir.supir_patch import SUPIRPatch
|
||||||
|
|
||||||
|
|
||||||
class BlockWiseControlBlock(torch.nn.Module):
|
class BlockWiseControlBlock(torch.nn.Module):
|
||||||
@ -266,6 +269,27 @@ class ModelPatchLoader:
|
|||||||
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
||||||
device=comfy.model_management.unet_offload_device(),
|
device=comfy.model_management.unet_offload_device(),
|
||||||
operations=comfy.ops.manual_cast)
|
operations=comfy.ops.manual_cast)
|
||||||
|
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
|
||||||
|
prefix_replace = {}
|
||||||
|
if 'model.control_model.input_hint_block.0.weight' in sd:
|
||||||
|
prefix_replace["model.control_model."] = "control_model."
|
||||||
|
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
|
||||||
|
else:
|
||||||
|
prefix_replace["control_model."] = "control_model."
|
||||||
|
prefix_replace["project_modules."] = "project_modules."
|
||||||
|
|
||||||
|
# Extract denoise_encoder weights before filter_keys discards them
|
||||||
|
de_prefix = "first_stage_model.denoise_encoder."
|
||||||
|
denoise_encoder_sd = {}
|
||||||
|
for k in list(sd.keys()):
|
||||||
|
if k.startswith(de_prefix):
|
||||||
|
denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k)
|
||||||
|
|
||||||
|
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
|
||||||
|
sd.pop("control_model.mask_LQ", None)
|
||||||
|
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
|
if denoise_encoder_sd:
|
||||||
|
model.denoise_encoder_sd = denoise_encoder_sd
|
||||||
|
|
||||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
||||||
@ -565,9 +589,89 @@ class MultiTalkModelPatch(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SUPIRApply(io.ComfyNode):
    """Apply a loaded SUPIR model patch to a model, conditioned on an input image."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: model, patch, VAE, image and strength/restore controls."""
        return io.Schema(
            node_id="SUPIRApply",
            category="model_patches/supir",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
                io.ModelPatch.Input("model_patch"),
                io.Vae.Input("vae"),
                io.Image.Input("image"),
                io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
                               tooltip="Control strength at the start of sampling (high sigma)."),
                io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
                               tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."),
                io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True,
                               tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."),
                io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True,
                               tooltip="Sigma threshold below which restore_cfg is disabled."),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def _encode_with_denoise_encoder(cls, vae, model_patch, image):
        """Encode using denoise_encoder weights from SUPIR checkpoint if available.

        When ``model_patch.model`` has no (or empty) ``denoise_encoder_sd`` the
        image is encoded with the VAE as-is. Otherwise the VAE's patcher is
        temporarily replaced by a clone carrying those weights as ``encoder.*``
        patches, and restored afterwards even if encoding fails.
        """
        denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None)
        if not denoise_sd:
            return vae.encode(image)

        # Fully prepare the patched clone BEFORE swapping it in, so a failure
        # while adding patches cannot leave the caller's VAE on the clone.
        orig_patcher = vae.patcher
        patched = orig_patcher.clone()
        patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()}
        patched.add_patches(patches, strength_patch=1.0, strength_model=0.0)

        vae.patcher = patched
        try:
            return vae.encode(image)
        finally:
            vae.patcher = orig_patcher

    @classmethod
    def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type,
                strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
        """Return a clone of ``model`` with the SUPIR patch (and optional restore guidance) attached.

        The first three image channels are encoded into a hint latent which
        drives ``SUPIRPatch``. When ``restore_cfg`` is greater than 0, a
        post-CFG function is also registered that pulls the denoised output
        toward a reference latent while sigma is above ``restore_cfg_s_tmin``.
        """
        model_patched = model.clone()
        latent_format = model.get_model_object("latent_format")
        hint_latent = latent_format.process_in(
            cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3]))
        patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
        patch.register(model_patched)

        if restore_cfg > 0.0:
            # Round-trip to match original pipeline: decode hint, re-encode with regular VAE
            decoded = vae.decode(latent_format.process_out(hint_latent))
            x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3]))
            # NOTE(review): hard-coded normalisation constant; presumably the
            # maximum sigma of the target model's schedule - confirm.
            sigma_max = 14.6146

            def restore_cfg_function(args):
                """Blend ``denoised`` toward ``x_center`` by (sigma / sigma_max) ** restore_cfg."""
                denoised = args["denoised"]
                sigma = args["sigma"]
                # The threshold test uses the first sigma of the batch only.
                if sigma.dim() > 0:
                    s = sigma[0].item()
                else:
                    s = sigma.item()
                if s > restore_cfg_s_tmin:
                    ref = x_center.to(device=denoised.device, dtype=denoised.dtype)
                    b = denoised.shape[0]
                    if ref.shape[0] != b:
                        # Broadcast a single reference, or tile and truncate a
                        # multi-item reference, to the denoised batch size.
                        ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b]
                    sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma
                    d_center = denoised - ref
                    denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg)
                return denoised

            model_patched.set_model_sampler_post_cfg_function(restore_cfg_function)

        return io.NodeOutput(model_patched)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ModelPatchLoader": ModelPatchLoader,
|
"ModelPatchLoader": ModelPatchLoader,
|
||||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||||
"ZImageFunControlnet": ZImageFunControlnet,
|
"ZImageFunControlnet": ZImageFunControlnet,
|
||||||
"USOStyleReference": USOStyleReference,
|
"USOStyleReference": USOStyleReference,
|
||||||
|
"SUPIRApply": SUPIRApply,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from PIL import Image
|
|||||||
import math
|
import math
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TypedDict, Literal
|
from typing import TypedDict, Literal
|
||||||
|
import kornia
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -660,6 +661,228 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
|
|||||||
return io.NodeOutput(batched)
|
return io.NodeOutput(batched)
|
||||||
|
|
||||||
|
|
||||||
|
class ColorTransfer(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ColorTransfer",
|
||||||
|
category="image/postprocessing",
|
||||||
|
description="Match the colors of one image to another using various algorithms.",
|
||||||
|
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
|
||||||
|
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
|
||||||
|
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
|
||||||
|
io.DynamicCombo.Input("source_stats",
|
||||||
|
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
|
||||||
|
options=[
|
||||||
|
io.DynamicCombo.Option("per_frame", []),
|
||||||
|
io.DynamicCombo.Option("uniform", []),
|
||||||
|
io.DynamicCombo.Option("target_frame", [
|
||||||
|
io.Int.Input("target_index", default=0, min=0, max=10000,
|
||||||
|
tooltip="Frame index used as the source baseline for computing the transform to image_ref"),
|
||||||
|
]),
|
||||||
|
]),
|
||||||
|
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(display_name="image"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_lab(images, i, device):
|
||||||
|
return kornia.color.rgb_to_lab(
|
||||||
|
images[i:i+1].to(device, dtype=torch.float32).permute(0, 3, 1, 2))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pool_stats(images, device, is_reinhard, eps):
|
||||||
|
"""Two-pass pooled mean + std/cov across all frames."""
|
||||||
|
N, C = images.shape[0], images.shape[3]
|
||||||
|
HW = images.shape[1] * images.shape[2]
|
||||||
|
mean = torch.zeros(C, 1, device=device, dtype=torch.float32)
|
||||||
|
for i in range(N):
|
||||||
|
mean += ColorTransfer._to_lab(images, i, device).view(C, -1).mean(dim=-1, keepdim=True)
|
||||||
|
mean /= N
|
||||||
|
acc = torch.zeros(C, 1 if is_reinhard else C, device=device, dtype=torch.float32)
|
||||||
|
for i in range(N):
|
||||||
|
centered = ColorTransfer._to_lab(images, i, device).view(C, -1) - mean
|
||||||
|
if is_reinhard:
|
||||||
|
acc += (centered * centered).mean(dim=-1, keepdim=True)
|
||||||
|
else:
|
||||||
|
acc += centered @ centered.T / HW
|
||||||
|
if is_reinhard:
|
||||||
|
return mean, torch.sqrt(acc / N).clamp_min_(eps)
|
||||||
|
return mean, acc / N
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _frame_stats(lab_flat, hw, is_reinhard, eps):
|
||||||
|
"""Per-frame mean + std/cov."""
|
||||||
|
mean = lab_flat.mean(dim=-1, keepdim=True)
|
||||||
|
if is_reinhard:
|
||||||
|
return mean, lab_flat.std(dim=-1, keepdim=True, unbiased=False).clamp_min_(eps)
|
||||||
|
centered = lab_flat - mean
|
||||||
|
return mean, centered @ centered.T / hw
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _mkl_matrix(cov_s, cov_r, eps):
|
||||||
|
"""Compute MKL 3x3 transform matrix from source and ref covariances."""
|
||||||
|
eig_val_s, eig_vec_s = torch.linalg.eigh(cov_s)
|
||||||
|
sqrt_val_s = torch.sqrt(eig_val_s.clamp_min(0)).clamp_min_(eps)
|
||||||
|
|
||||||
|
scaled_V = eig_vec_s * sqrt_val_s.unsqueeze(0)
|
||||||
|
mid = scaled_V.T @ cov_r @ scaled_V
|
||||||
|
eig_val_m, eig_vec_m = torch.linalg.eigh(mid)
|
||||||
|
sqrt_m = torch.sqrt(eig_val_m.clamp_min(0))
|
||||||
|
|
||||||
|
inv_sqrt_s = 1.0 / sqrt_val_s
|
||||||
|
inv_scaled_V = eig_vec_s * inv_sqrt_s.unsqueeze(0)
|
||||||
|
M_half = (eig_vec_m * sqrt_m.unsqueeze(0)) @ eig_vec_m.T
|
||||||
|
return inv_scaled_V @ M_half @ inv_scaled_V.T
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _histogram_lut(src, ref, bins=256):
|
||||||
|
"""Build per-channel LUT from source and ref histograms. src/ref: (C, HW) in [0,1]."""
|
||||||
|
s_bins = (src * (bins - 1)).long().clamp(0, bins - 1)
|
||||||
|
r_bins = (ref * (bins - 1)).long().clamp(0, bins - 1)
|
||||||
|
s_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
|
||||||
|
r_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
|
||||||
|
ones_s = torch.ones_like(src)
|
||||||
|
ones_r = torch.ones_like(ref)
|
||||||
|
s_hist.scatter_add_(1, s_bins, ones_s)
|
||||||
|
r_hist.scatter_add_(1, r_bins, ones_r)
|
||||||
|
s_cdf = s_hist.cumsum(1)
|
||||||
|
s_cdf = s_cdf / s_cdf[:, -1:]
|
||||||
|
r_cdf = r_hist.cumsum(1)
|
||||||
|
r_cdf = r_cdf / r_cdf[:, -1:]
|
||||||
|
return torch.searchsorted(r_cdf, s_cdf).clamp_max_(bins - 1).float() / (bins - 1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _pooled_cdf(cls, images, device, num_bins=256):
|
||||||
|
"""Build pooled CDF across all frames, one frame at a time."""
|
||||||
|
C = images.shape[3]
|
||||||
|
hist = torch.zeros(C, num_bins, device=device, dtype=torch.float32)
|
||||||
|
for i in range(images.shape[0]):
|
||||||
|
frame = images[i].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
|
||||||
|
bins = (frame * (num_bins - 1)).long().clamp(0, num_bins - 1)
|
||||||
|
hist.scatter_add_(1, bins, torch.ones_like(frame))
|
||||||
|
cdf = hist.cumsum(1)
|
||||||
|
return cdf / cdf[:, -1:]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _build_histogram_transform(cls, image_target, image_ref, device, stats_mode, target_index, B):
|
||||||
|
"""Build per-frame or uniform LUT transform for histogram mode."""
|
||||||
|
if stats_mode == 'per_frame':
|
||||||
|
return None # LUT computed per-frame in the apply loop
|
||||||
|
|
||||||
|
r_cdf = cls._pooled_cdf(image_ref, device)
|
||||||
|
if stats_mode == 'target_frame':
|
||||||
|
ti = min(target_index, B - 1)
|
||||||
|
s_cdf = cls._pooled_cdf(image_target[ti:ti+1], device)
|
||||||
|
else:
|
||||||
|
s_cdf = cls._pooled_cdf(image_target, device)
|
||||||
|
return torch.searchsorted(r_cdf, s_cdf).clamp_max_(255).float() / 255.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _build_lab_transform(cls, image_target, image_ref, device, stats_mode, target_index, is_reinhard):
|
||||||
|
"""Build transform parameters for Lab-based methods. Returns a transform function."""
|
||||||
|
eps = 1e-6
|
||||||
|
B, H, W, C = image_target.shape
|
||||||
|
B_ref = image_ref.shape[0]
|
||||||
|
single_ref = B_ref == 1
|
||||||
|
HW = H * W
|
||||||
|
HW_ref = image_ref.shape[1] * image_ref.shape[2]
|
||||||
|
|
||||||
|
# Precompute ref stats
|
||||||
|
if single_ref or stats_mode in ('uniform', 'target_frame'):
|
||||||
|
ref_mean, ref_sc = cls._pool_stats(image_ref, device, is_reinhard, eps)
|
||||||
|
|
||||||
|
# Uniform/target_frame: precompute single affine transform
|
||||||
|
if stats_mode in ('uniform', 'target_frame'):
|
||||||
|
if stats_mode == 'target_frame':
|
||||||
|
ti = min(target_index, B - 1)
|
||||||
|
s_lab = cls._to_lab(image_target, ti, device).view(C, -1)
|
||||||
|
s_mean, s_sc = cls._frame_stats(s_lab, HW, is_reinhard, eps)
|
||||||
|
else:
|
||||||
|
s_mean, s_sc = cls._pool_stats(image_target, device, is_reinhard, eps)
|
||||||
|
|
||||||
|
if is_reinhard:
|
||||||
|
scale = ref_sc / s_sc
|
||||||
|
offset = ref_mean - scale * s_mean
|
||||||
|
return lambda src_flat, **_: src_flat * scale + offset
|
||||||
|
T = cls._mkl_matrix(s_sc, ref_sc, eps)
|
||||||
|
offset = ref_mean - T @ s_mean
|
||||||
|
return lambda src_flat, **_: T @ src_flat + offset
|
||||||
|
|
||||||
|
# per_frame
|
||||||
|
def per_frame_transform(src_flat, frame_idx):
|
||||||
|
s_mean, s_sc = cls._frame_stats(src_flat, HW, is_reinhard, eps)
|
||||||
|
|
||||||
|
if single_ref:
|
||||||
|
r_mean, r_sc = ref_mean, ref_sc
|
||||||
|
else:
|
||||||
|
ri = min(frame_idx, B_ref - 1)
|
||||||
|
r_mean, r_sc = cls._frame_stats(cls._to_lab(image_ref, ri, device).view(C, -1), HW_ref, is_reinhard, eps)
|
||||||
|
|
||||||
|
centered = src_flat - s_mean
|
||||||
|
if is_reinhard:
|
||||||
|
return centered * (r_sc / s_sc) + r_mean
|
||||||
|
T = cls._mkl_matrix(centered @ centered.T / HW, r_sc, eps)
|
||||||
|
return T @ centered + r_mean
|
||||||
|
|
||||||
|
return per_frame_transform
|
||||||
|
|
||||||
|
@classmethod
def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput:
    """Apply a colour transfer from ``image_ref`` to every frame of ``image_target``.

    Args:
        image_target: Batch unpacked as ``(B, H, W, C)``; frames are processed one at a time.
        image_ref: Reference batch, or ``None`` to pass the target through untouched.
        method: ``'histogram'`` selects per-channel LUT matching; any other value
            selects the Lab transform (``"reinhard_lab"`` turns on ``is_reinhard``).
        source_stats: Mapping holding the stats mode under ``"source_stats"`` and an
            optional ``"target_index"`` (defaults to 0).
        strength: Blend factor between the source frame and the corrected frame;
            ``0`` returns the input unchanged, ``1.0`` skips the blend.

    Returns:
        ``io.NodeOutput`` wrapping the corrected batch, clamped to ``[0, 1]`` and stored
        on the intermediate device in the intermediate dtype.
    """
    mode = source_stats["source_stats"]
    target_index = source_stats.get("target_index", 0)

    # Nothing to compute: hand the input straight back.
    if strength == 0 or image_ref is None:
        return io.NodeOutput(image_target)

    compute_device = comfy.model_management.get_torch_device()
    out_device = comfy.model_management.intermediate_device()
    out_dtype = comfy.model_management.intermediate_dtype()

    B, H, W, C = image_target.shape
    last_ref = image_ref.shape[0] - 1
    pbar = comfy.utils.ProgressBar(B)
    out = torch.empty(B, H, W, C, device=out_device, dtype=out_dtype)
    full_strength = strength == 1.0

    def store(idx, chw):
        # CHW -> HWC, clamp in place, then move to the output device/dtype.
        out[idx] = chw.permute(1, 2, 0).clamp_(0, 1).to(device=out_device, dtype=out_dtype)
        pbar.update(1)

    if method == 'histogram':
        shared_lut = cls._build_histogram_transform(
            image_target, image_ref, compute_device, mode, target_index, B)

        for idx in range(B):
            frame = image_target[idx].to(compute_device, dtype=torch.float32).permute(2, 0, 1)
            flat = frame.reshape(C, -1)

            lut = shared_lut
            if lut is None:
                # No shared LUT: build one against the matching reference frame
                # (the last reference frame is reused once the reference batch runs out).
                ref_flat = (
                    image_ref[min(idx, last_ref)]
                    .to(compute_device, dtype=torch.float32)
                    .permute(2, 0, 1)
                    .reshape(C, -1)
                )
                lut = cls._histogram_lut(flat, ref_flat)

            bins = (flat * 255).long().clamp(0, 255)
            matched = lut.gather(1, bins).view(C, H, W)
            store(idx, matched if full_strength else torch.lerp(frame, matched, strength))

        return io.NodeOutput(out)

    transform = cls._build_lab_transform(
        image_target, image_ref, compute_device, mode, target_index,
        is_reinhard=method == "reinhard_lab")

    for idx in range(B):
        lab = cls._to_lab(image_target, idx, compute_device)
        corrected = transform(lab.view(C, -1), frame_idx=idx).view(1, C, H, W)
        # The blend happens in Lab space, before converting back to RGB.
        blended = corrected if full_strength else torch.lerp(lab, corrected, strength)
        store(idx, kornia.color.lab_to_rgb(blended).squeeze(0))

    return io.NodeOutput(out)
|
||||||
|
|
||||||
|
|
||||||
class PostProcessingExtension(ComfyExtension):
|
class PostProcessingExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -673,6 +896,7 @@ class PostProcessingExtension(ComfyExtension):
|
|||||||
BatchImagesNode,
|
BatchImagesNode,
|
||||||
BatchMasksNode,
|
BatchMasksNode,
|
||||||
BatchLatentsNode,
|
BatchLatentsNode,
|
||||||
|
ColorTransfer,
|
||||||
# BatchImagesMasksLatentsNode,
|
# BatchImagesMasksLatentsNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ class PreviewAny():
|
|||||||
"required": {"source": (IO.ANY, {})},
|
"required": {"source": (IO.ANY, {})},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ()
|
RETURN_TYPES = (IO.STRING,)
|
||||||
FUNCTION = "main"
|
FUNCTION = "main"
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ class PreviewAny():
|
|||||||
except Exception:
|
except Exception:
|
||||||
value = 'source exists, but could not be serialized.'
|
value = 'source exists, but could not be serialized.'
|
||||||
|
|
||||||
return {"ui": {"text": (value,)}}
|
return {"ui": {"text": (value,)}, "result": (value,)}
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"PreviewAny": PreviewAny,
|
"PreviewAny": PreviewAny,
|
||||||
|
|||||||
@ -32,10 +32,12 @@ class RTDETR_detect(io.ComfyNode):
|
|||||||
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
|
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
|
||||||
B, H, W, C = image.shape
|
B, H, W, C = image.shape
|
||||||
|
|
||||||
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
|
||||||
|
|
||||||
comfy.model_management.load_model_gpu(model)
|
comfy.model_management.load_model_gpu(model)
|
||||||
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
|
results = []
|
||||||
|
for i in range(0, B, 32):
|
||||||
|
batch = image[i:i + 32]
|
||||||
|
image_in = comfy.utils.common_upscale(batch.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
||||||
|
results.extend(model.model.diffusion_model(image_in, (W, H)))
|
||||||
|
|
||||||
all_bbox_dicts = []
|
all_bbox_dicts = []
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import colorsys
|
import colorsys
|
||||||
@ -410,7 +411,9 @@ class SDPoseDrawKeypoints(io.ComfyNode):
|
|||||||
pose_outputs.append(canvas)
|
pose_outputs.append(canvas)
|
||||||
|
|
||||||
pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0)
|
pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0)
|
||||||
final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0
|
final_pose_output = torch.from_numpy(pose_outputs_np).to(
|
||||||
|
device=comfy.model_management.intermediate_device(),
|
||||||
|
dtype=comfy.model_management.intermediate_dtype()) / 255.0
|
||||||
return io.NodeOutput(final_pose_output)
|
return io.NodeOutput(final_pose_output)
|
||||||
|
|
||||||
class SDPoseKeypointExtractor(io.ComfyNode):
|
class SDPoseKeypointExtractor(io.ComfyNode):
|
||||||
@ -459,6 +462,27 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
|||||||
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
|
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
|
||||||
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
|
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
|
||||||
|
|
||||||
|
def _resize_to_model(imgs):
    """Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w).

    ``model_h`` and ``model_w`` are taken from the enclosing scope.

    Args:
        imgs: Image batch laid out as (B, H, W, C).

    Returns:
        Tuple ``(resized_bhwc, scale, pad_top, pad_left)``: the padded batch, the
        uniform scale factor that was applied, and the top/left offsets at which
        the scaled image sits inside the padded canvas.
    """
    h, w = imgs.shape[-3], imgs.shape[-2]
    scale = min(model_h / h, model_w / w)
    # Extremely thin inputs can round to 0 pixels along the short side, which the
    # resize cannot produce; keep at least one pixel in each dimension.
    sh = max(1, int(round(h * scale)))
    sw = max(1, int(round(w * scale)))
    # Centre the scaled image inside the model-sized canvas.
    pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
    chw = imgs.permute(0, 3, 1, 2).float()  # BHWC -> BCHW
    scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
    padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
    padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
    return padded.permute(0, 2, 3, 1), scale, pt, pl  # BCHW -> BHWC
|
||||||
|
|
||||||
|
def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
|
||||||
|
"""Remap keypoints from model space back to original image space."""
|
||||||
|
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
|
||||||
|
invalid = kp[..., 0] < 0
|
||||||
|
kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
|
||||||
|
kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
|
||||||
|
kp[invalid] = -1
|
||||||
|
return kp
|
||||||
|
|
||||||
def _run_on_latent(latent_batch):
|
def _run_on_latent(latent_batch):
|
||||||
"""Run one forward pass and return (keypoints_list, scores_list) for the batch."""
|
"""Run one forward pass and return (keypoints_list, scores_list) for the batch."""
|
||||||
nonlocal captured_feat
|
nonlocal captured_feat
|
||||||
@ -504,36 +528,19 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
|||||||
if x2 <= x1 or y2 <= y1:
|
if x2 <= x1 or y2 <= y1:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
crop_h_px, crop_w_px = y2 - y1, x2 - x1
|
|
||||||
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
|
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
|
||||||
|
crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
|
||||||
# scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size.
|
|
||||||
scale = min(model_h / crop_h_px, model_w / crop_w_px)
|
|
||||||
scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale))
|
|
||||||
pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2
|
|
||||||
|
|
||||||
crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW
|
|
||||||
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
|
|
||||||
padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
|
|
||||||
padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
|
|
||||||
crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC
|
|
||||||
|
|
||||||
latent_crop = vae.encode(crop_resized)
|
latent_crop = vae.encode(crop_resized)
|
||||||
kp_batch, sc_batch = _run_on_latent(latent_crop)
|
kp_batch, sc_batch = _run_on_latent(latent_crop)
|
||||||
kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space
|
kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
|
||||||
|
|
||||||
# remove padding offset, undo scale, offset to full-image coordinates.
|
|
||||||
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
|
|
||||||
kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1
|
|
||||||
kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1
|
|
||||||
|
|
||||||
img_keypoints.append(kp)
|
img_keypoints.append(kp)
|
||||||
img_scores.append(sc)
|
img_scores.append(sc_batch[0])
|
||||||
else:
|
else:
|
||||||
# No bboxes for this image – run on the full image
|
img_resized, scale, pad_top, pad_left = _resize_to_model(img)
|
||||||
latent_img = vae.encode(img)
|
latent_img = vae.encode(img_resized)
|
||||||
kp_batch, sc_batch = _run_on_latent(latent_img)
|
kp_batch, sc_batch = _run_on_latent(latent_img)
|
||||||
img_keypoints.append(kp_batch[0])
|
img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
|
||||||
img_scores.append(sc_batch[0])
|
img_scores.append(sc_batch[0])
|
||||||
|
|
||||||
all_keypoints.append(img_keypoints)
|
all_keypoints.append(img_keypoints)
|
||||||
@ -541,19 +548,16 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
|||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
else: # full-image mode, batched
|
else: # full-image mode, batched
|
||||||
tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints")
|
for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
|
||||||
for batch_start in range(0, total_images, batch_size):
|
batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
|
||||||
batch_end = min(batch_start + batch_size, total_images)
|
latent_batch = vae.encode(batch_resized)
|
||||||
latent_batch = vae.encode(image[batch_start:batch_end])
|
|
||||||
|
|
||||||
kp_batch, sc_batch = _run_on_latent(latent_batch)
|
kp_batch, sc_batch = _run_on_latent(latent_batch)
|
||||||
|
|
||||||
for kp, sc in zip(kp_batch, sc_batch):
|
for kp, sc in zip(kp_batch, sc_batch):
|
||||||
all_keypoints.append([kp])
|
all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
|
||||||
all_scores.append([sc])
|
all_scores.append([sc])
|
||||||
tqdm_pbar.update(1)
|
|
||||||
|
|
||||||
pbar.update(batch_end - batch_start)
|
pbar.update(len(kp_batch))
|
||||||
|
|
||||||
openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
|
openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
|
||||||
return io.NodeOutput(openpose_frames)
|
return io.NodeOutput(openpose_frames)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
import json
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
@ -375,6 +376,39 @@ class RegexReplace(io.ComfyNode):
|
|||||||
return io.NodeOutput(result)
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
|
class JsonExtractString(io.ComfyNode):
    """Node that reads one top-level key from a JSON object and returns it as text."""

    @classmethod
    def define_schema(cls):
        """Declare the node: two string inputs (JSON text and key), one string output."""
        return io.Schema(
            node_id="JsonExtractString",
            display_name="Extract String from JSON",
            category="utils/string",
            search_aliases=["json", "extract json", "parse json", "json value", "read json"],
            inputs=[
                io.String.Input("json_string", multiline=True),
                io.String.Input("key", multiline=False),
            ],
            outputs=[
                io.String.Output(),
            ]
        )

    @classmethod
    def execute(cls, json_string, key):
        """Return ``str(data[key])`` for a JSON object, or an empty string.

        The empty string is produced when the text does not parse, the parsed
        value is not an object, the key is absent, or the stored value is null.
        """
        extracted = ""
        try:
            parsed = json.loads(json_string)
            if isinstance(parsed, dict) and key in parsed:
                hit = parsed[key]
                # JSON null maps to an empty string rather than "None".
                if hit is not None:
                    extracted = str(hit)
        except (json.JSONDecodeError, TypeError):
            extracted = ""
        return io.NodeOutput(extracted)
|
||||||
|
|
||||||
class StringExtension(ComfyExtension):
|
class StringExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -390,6 +424,7 @@ class StringExtension(ComfyExtension):
|
|||||||
RegexMatch,
|
RegexMatch,
|
||||||
RegexExtract,
|
RegexExtract,
|
||||||
RegexReplace,
|
RegexReplace,
|
||||||
|
JsonExtractString,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> StringExtension:
|
async def comfy_entrypoint() -> StringExtension:
|
||||||
|
|||||||
@ -35,6 +35,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||||
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||||
|
io.Boolean.Input("use_default_template", optional=True, default=True, tooltip="Use the built in system prompt/template if the model has one.", advanced=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.String.Output(display_name="generated_text"),
|
io.String.Output(display_name="generated_text"),
|
||||||
@ -42,9 +43,9 @@ class TextGenerate(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
|
||||||
|
|
||||||
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)
|
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
|
||||||
|
|
||||||
# Get sampling parameters from dynamic combo
|
# Get sampling parameters from dynamic combo
|
||||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||||
@ -160,12 +161,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
|
||||||
if image is None:
|
if image is None:
|
||||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||||
else:
|
else:
|
||||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
|
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
|
||||||
|
|
||||||
|
|
||||||
class TextgenExtension(ComfyExtension):
|
class TextgenExtension(ComfyExtension):
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import comfy.utils
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from spandrel_extra_arches import EXTRA_REGISTRY
|
from spandrel_extra_arches import EXTRA_REGISTRY
|
||||||
@ -78,13 +79,15 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
|||||||
tile = 512
|
tile = 512
|
||||||
overlap = 32
|
overlap = 32
|
||||||
|
|
||||||
|
output_device = comfy.model_management.intermediate_device()
|
||||||
|
|
||||||
oom = True
|
oom = True
|
||||||
try:
|
try:
|
||||||
while oom:
|
while oom:
|
||||||
try:
|
try:
|
||||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
||||||
oom = False
|
oom = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model_management.raise_non_oom(e)
|
model_management.raise_non_oom(e)
|
||||||
@ -94,7 +97,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
|||||||
finally:
|
finally:
|
||||||
upscale_model.to("cpu")
|
upscale_model.to("cpu")
|
||||||
|
|
||||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
||||||
return io.NodeOutput(s)
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
upscale = execute # TODO: remove
|
upscale = execute # TODO: remove
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.18.1"
|
__version__ = "0.19.3"
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.18.1"
|
version = "0.19.3"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.42.8
|
comfyui-frontend-package==1.42.11
|
||||||
comfyui-workflow-templates==0.9.44
|
comfyui-workflow-templates==0.9.57
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user