mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 09:49:26 +08:00
Merge branch 'master' into harden-dataset-load
This commit is contained in:
commit
e273558f89
@ -1,5 +1,4 @@
|
|||||||
As of the time of writing this you need this driver for best results:
|
As of the time of writing this you need a recent driver. Updating to the latest driver is recommended.
|
||||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
|
|
||||||
|
|
||||||
HOW TO RUN:
|
HOW TO RUN:
|
||||||
|
|
||||||
@ -7,9 +6,9 @@ If you have a AMD gpu:
|
|||||||
|
|
||||||
run_amd_gpu.bat
|
run_amd_gpu.bat
|
||||||
|
|
||||||
If you have memory issues you can try disabling the smart memory management by running comfyui with:
|
If you have memory issues you can try enabling the new dynamic memory management by running comfyui with:
|
||||||
|
|
||||||
run_amd_gpu_disable_smart_memory.bat
|
run_amd_gpu_enable_dynamic_vram.bat
|
||||||
|
|
||||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/check-line-endings.yml
vendored
2
.github/workflows/check-line-endings.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
|||||||
- name: Check for Windows line endings (CRLF)
|
- name: Check for Windows line endings (CRLF)
|
||||||
run: |
|
run: |
|
||||||
# Get the list of changed files in the PR
|
# Get the list of changed files in the PR
|
||||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
|
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- ':!.ci')
|
||||||
|
|
||||||
# Flag to track if CRLF is found
|
# Flag to track if CRLF is found
|
||||||
CRLF_FOUND=false
|
CRLF_FOUND=false
|
||||||
|
|||||||
297
comfy/ldm/ideogram4/model.py
Normal file
297
comfy/ldm/ideogram4/model.py
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
"""
|
||||||
|
The Ideogram 4 transformer is a NextDiT/Lumina2-family single-stream model
|
||||||
|
consumes Qwen3-VL hidden-state features (concatenated from 13 layers -> 53248 dims)
|
||||||
|
packs ``[text tokens, image tokens]`` into one sequence with block-diagonal segment attention and 3D interleaved MRoPE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import comfy.patcher_extension
|
||||||
|
from comfy.ldm.lumina.model import FeedForward
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||||
|
from comfy.text_encoders.llama import apply_rope, precompute_freqs_cis
|
||||||
|
|
||||||
|
# Per-token role indicators
|
||||||
|
SEQUENCE_PADDING_INDICATOR = -1
|
||||||
|
OUTPUT_IMAGE_INDICATOR = 2
|
||||||
|
LLM_TOKEN_INDICATOR = 3
|
||||||
|
# Image grid coordinates are offset so they never collide with text positions
|
||||||
|
IMAGE_POSITION_OFFSET = 65536
|
||||||
|
|
||||||
|
|
||||||
|
class Ideogram4Attention(nn.Module):
|
||||||
|
def __init__(self, hidden_size, num_heads, eps=1e-5, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = hidden_size // num_heads
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
self.qkv = operations.Linear(hidden_size, hidden_size * 3, bias=False, dtype=dtype, device=device)
|
||||||
|
self.norm_q = operations.RMSNorm(self.head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||||
|
self.norm_k = operations.RMSNorm(self.head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||||
|
self.o = operations.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x, attn_mask, freqs_cis, transformer_options={}):
|
||||||
|
batch_size, seq_len, _ = x.shape
|
||||||
|
qkv = self.qkv(x).view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
|
||||||
|
q, k, v = qkv.unbind(dim=2)
|
||||||
|
|
||||||
|
q = self.norm_q(q)
|
||||||
|
k = self.norm_k(k)
|
||||||
|
|
||||||
|
# (B, heads, L, head_dim)
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
q, k = apply_rope(q, k, freqs_cis)
|
||||||
|
|
||||||
|
out = optimized_attention_masked(q, k, v, self.num_heads, attn_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||||
|
return self.o(out)
|
||||||
|
|
||||||
|
|
||||||
|
class Ideogram4TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size, intermediate_size, num_heads, norm_eps, adaln_dim, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = Ideogram4Attention(hidden_size, num_heads, eps=1e-5, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.feed_forward = FeedForward(
|
||||||
|
dim=hidden_size, hidden_dim=intermediate_size, multiple_of=1, ffn_dim_multiplier=None,
|
||||||
|
operation_settings={"operations": operations, "dtype": dtype, "device": device},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attention_norm1 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||||
|
self.ffn_norm1 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||||
|
self.attention_norm2 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||||
|
self.ffn_norm2 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.adaln_modulation = operations.Linear(adaln_dim, 4 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x, attn_mask, freqs_cis, adaln_input, transformer_options={}):
|
||||||
|
mod = self.adaln_modulation(adaln_input)
|
||||||
|
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.chunk(4, dim=-1)
|
||||||
|
gate_msa = torch.tanh(gate_msa)
|
||||||
|
gate_mlp = torch.tanh(gate_mlp)
|
||||||
|
scale_msa = 1.0 + scale_msa
|
||||||
|
scale_mlp = 1.0 + scale_mlp
|
||||||
|
|
||||||
|
attn_out = self.attention(self.attention_norm1(x) * scale_msa, attn_mask, freqs_cis, transformer_options=transformer_options)
|
||||||
|
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||||
|
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _sinusoidal_embedding(t, dim, scale=1e4):
|
||||||
|
t = t.to(torch.float32)
|
||||||
|
half = dim // 2
|
||||||
|
freq = math.log(scale) / (half - 1)
|
||||||
|
freq = torch.exp(torch.arange(half, dtype=torch.float32, device=t.device) * -freq)
|
||||||
|
emb = t.unsqueeze(-1) * freq
|
||||||
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||||
|
if dim % 2 == 1:
|
||||||
|
emb = F.pad(emb, (0, 1))
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class Ideogram4EmbedScalar(nn.Module):
|
||||||
|
def __init__(self, dim, input_range=(0.0, 1.0), dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.range_min, self.range_max = input_range
|
||||||
|
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||||
|
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min)
|
||||||
|
emb = _sinusoidal_embedding(scaled, self.dim)
|
||||||
|
emb = emb.to(self.mlp_in.weight.dtype)
|
||||||
|
emb = F.silu(self.mlp_in(emb))
|
||||||
|
return self.mlp_out(emb)
|
||||||
|
|
||||||
|
|
||||||
|
class Ideogram4FinalLayer(nn.Module):
|
||||||
|
def __init__(self, hidden_size, out_channels, adaln_dim, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False, dtype=dtype, device=device)
|
||||||
|
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
self.adaln_modulation = operations.Linear(adaln_dim, hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x, c):
|
||||||
|
scale = 1.0 + self.adaln_modulation(F.silu(c))
|
||||||
|
return self.linear(self.norm_final(x) * scale)
|
||||||
|
|
||||||
|
|
||||||
|
class Ideogram4Transformer(nn.Module):
|
||||||
|
"""A single Ideogram 4 backbone operating on a packed token sequence."""
|
||||||
|
|
||||||
|
def __init__(self, emb_dim, num_layers, num_heads, intermediate_size, adaln_dim,
|
||||||
|
in_channels, llm_features_dim, rope_theta, mrope_section, norm_eps,
|
||||||
|
dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.head_dim = emb_dim // num_heads
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.mrope_section = tuple(mrope_section)
|
||||||
|
|
||||||
|
self.input_proj = operations.Linear(in_channels, emb_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
self.llm_cond_norm = operations.RMSNorm(llm_features_dim, eps=1e-6, elementwise_affine=True, dtype=dtype, device=device)
|
||||||
|
self.llm_cond_proj = operations.Linear(llm_features_dim, emb_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
self.t_embedding = Ideogram4EmbedScalar(emb_dim, input_range=(0.0, 1.0), dtype=dtype, device=device, operations=operations)
|
||||||
|
self.adaln_proj = operations.Linear(emb_dim, adaln_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.embed_image_indicator = operations.Embedding(2, emb_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
Ideogram4TransformerBlock(emb_dim, intermediate_size, num_heads, norm_eps, adaln_dim,
|
||||||
|
dtype=dtype, device=device, operations=operations)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.final_layer = Ideogram4FinalLayer(emb_dim, in_channels, adaln_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def _backbone(self, llm_features, x, t, position_ids, attn_mask, indicator, transformer_options={}):
|
||||||
|
indicator = indicator.to(torch.long)
|
||||||
|
output_image_mask = (indicator == OUTPUT_IMAGE_INDICATOR).to(x.dtype).unsqueeze(-1)
|
||||||
|
|
||||||
|
x = x * output_image_mask
|
||||||
|
h = self.input_proj(x) * output_image_mask
|
||||||
|
|
||||||
|
t_cond = self.t_embedding(t)
|
||||||
|
if t.dim() == 1:
|
||||||
|
t_cond = t_cond.unsqueeze(1)
|
||||||
|
adaln_input = F.silu(self.adaln_proj(t_cond))
|
||||||
|
|
||||||
|
# h is zero on the text rows (content lives only on image rows), add writes the text features in place
|
||||||
|
if llm_features is not None:
|
||||||
|
L_text = llm_features.shape[1]
|
||||||
|
text_mask = (indicator[:, :L_text] == LLM_TOKEN_INDICATOR).to(x.dtype).unsqueeze(-1)
|
||||||
|
llm = self.llm_cond_norm(llm_features * text_mask)
|
||||||
|
llm = self.llm_cond_proj(llm) * text_mask
|
||||||
|
h[:, :L_text] = h[:, :L_text] + llm
|
||||||
|
|
||||||
|
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype)
|
||||||
|
|
||||||
|
# Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch).
|
||||||
|
freqs_cis = precompute_freqs_cis(
|
||||||
|
self.head_dim, position_ids[0].transpose(0, 1), self.rope_theta,
|
||||||
|
rope_dims=self.mrope_section, interleaved_mrope=True, device=position_ids.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_mask is not None and attn_mask.dtype == torch.bool:
|
||||||
|
attn_mask = torch.zeros_like(attn_mask, dtype=h.dtype).masked_fill_(~attn_mask, -torch.finfo(h.dtype).max)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
h = layer(h, attn_mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
return self.final_layer(h, adaln_input)
|
||||||
|
|
||||||
|
|
||||||
|
class Ideogram4Transformer2DModel(Ideogram4Transformer):
|
||||||
|
"""Ideogram 4 single-stream DiT.
|
||||||
|
|
||||||
|
Runs a packed ``[text, image]`` sequence when text context is supplied, or an image-only sequence when ``context is None``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, image_model=None, in_channels=128, num_layers=34, num_attention_heads=18, attention_head_dim=256, intermediate_size=12288,
|
||||||
|
adaln_dim=512, llm_features_dim=53248, rope_theta=5000000, mrope_section=(24, 20, 20), norm_eps=1e-5,
|
||||||
|
dtype=None, device=None, operations=None, **kwargs):
|
||||||
|
emb_dim = num_attention_heads * attention_head_dim
|
||||||
|
super().__init__(
|
||||||
|
emb_dim=emb_dim, num_layers=num_layers, num_heads=num_attention_heads,
|
||||||
|
intermediate_size=intermediate_size, adaln_dim=adaln_dim, in_channels=in_channels,
|
||||||
|
llm_features_dim=llm_features_dim, rope_theta=rope_theta, mrope_section=mrope_section,
|
||||||
|
norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels
|
||||||
|
# 128-dim token = patch (2x2) * ae_channels (32).
|
||||||
|
self.patch_size = 2
|
||||||
|
self.ae_channels = in_channels // (self.patch_size * self.patch_size)
|
||||||
|
|
||||||
|
def _img_to_tokens(self, x):
|
||||||
|
B, C, gh, gw = x.shape
|
||||||
|
x = x.view(B, self.ae_channels, self.patch_size, self.patch_size, gh, gw)
|
||||||
|
x = x.permute(0, 4, 5, 2, 3, 1) # (B, gh, gw, pi, pj, c)
|
||||||
|
return x.reshape(B, gh * gw, C)
|
||||||
|
|
||||||
|
def _tokens_to_img(self, tokens, gh, gw):
|
||||||
|
B = tokens.shape[0]
|
||||||
|
C = tokens.shape[-1]
|
||||||
|
x = tokens.reshape(B, gh, gw, self.patch_size, self.patch_size, self.ae_channels)
|
||||||
|
x = x.permute(0, 5, 3, 4, 1, 2) # (B, c, pi, pj, gh, gw)
|
||||||
|
return x.reshape(B, C, gh, gw)
|
||||||
|
|
||||||
|
def _image_position_ids(self, gh, gw, device):
|
||||||
|
h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1)
|
||||||
|
w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1)
|
||||||
|
t_idx = torch.zeros_like(h_idx)
|
||||||
|
return torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET # (L_img, 3)
|
||||||
|
|
||||||
|
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
|
||||||
|
B = x_chunk.shape[0]
|
||||||
|
device = x_chunk.device
|
||||||
|
img_tokens = self._img_to_tokens(x_chunk)
|
||||||
|
L_img = img_tokens.shape[1]
|
||||||
|
L_text = context_chunk.shape[1]
|
||||||
|
L = L_text + L_img
|
||||||
|
latent_dim = img_tokens.shape[-1]
|
||||||
|
|
||||||
|
x_full = torch.zeros(B, L, latent_dim, dtype=img_tokens.dtype, device=device)
|
||||||
|
x_full[:, L_text:] = img_tokens
|
||||||
|
|
||||||
|
text_pos = torch.arange(L_text, device=device).view(-1, 1).expand(L_text, 3)
|
||||||
|
img_pos = self._image_position_ids(gh, gw, device)
|
||||||
|
position_ids = torch.cat([text_pos, img_pos], dim=0).unsqueeze(0).expand(B, L, 3)
|
||||||
|
|
||||||
|
indicator = torch.empty(B, L, dtype=torch.long, device=device)
|
||||||
|
indicator[:, :L_text] = LLM_TOKEN_INDICATOR
|
||||||
|
indicator[:, L_text:] = OUTPUT_IMAGE_INDICATOR
|
||||||
|
|
||||||
|
attn_mask = None
|
||||||
|
if attn_mask_chunk is not None:
|
||||||
|
segment_ids = torch.ones(B, L, dtype=torch.long, device=device)
|
||||||
|
pad = (attn_mask_chunk == 0)
|
||||||
|
segment_ids[:, :L_text][pad] = SEQUENCE_PADDING_INDICATOR
|
||||||
|
indicator[:, :L_text][pad] = 0
|
||||||
|
# Block-diagonal mask from segment ids: (B, 1, L, L), True = attend.
|
||||||
|
attn_mask = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).unsqueeze(1)
|
||||||
|
|
||||||
|
out = self._backbone(context_chunk, x_full, t_chunk, position_ids, attn_mask, indicator,
|
||||||
|
transformer_options=transformer_options)
|
||||||
|
return self._tokens_to_img(out[:, L_text:], gh, gw)
|
||||||
|
|
||||||
|
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
|
||||||
|
B = x_chunk.shape[0]
|
||||||
|
device = x_chunk.device
|
||||||
|
img_tokens = self._img_to_tokens(x_chunk)
|
||||||
|
L_img = img_tokens.shape[1]
|
||||||
|
|
||||||
|
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)
|
||||||
|
indicator = torch.full((B, L_img), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# Image-only sequence is a single segment -> no mask, full attention, no LLM context.
|
||||||
|
out = self._backbone(None, img_tokens, t_chunk, position_ids, None, indicator, transformer_options=transformer_options)
|
||||||
|
return self._tokens_to_img(out, gh, gw)
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||||
|
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||||
|
bs, c, gh, gw = x.shape
|
||||||
|
|
||||||
|
timesteps = 1.0 - timesteps
|
||||||
|
|
||||||
|
# unconditional pass
|
||||||
|
if context is None:
|
||||||
|
return -self._run_image_only(x, timesteps, gh, gw, transformer_options)
|
||||||
|
|
||||||
|
return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options)
|
||||||
@ -55,6 +55,7 @@ import comfy.ldm.pixeldit.pid
|
|||||||
import comfy.ldm.ace.model
|
import comfy.ldm.ace.model
|
||||||
import comfy.ldm.omnigen.omnigen2
|
import comfy.ldm.omnigen.omnigen2
|
||||||
import comfy.ldm.qwen_image.model
|
import comfy.ldm.qwen_image.model
|
||||||
|
import comfy.ldm.ideogram4.model
|
||||||
import comfy.ldm.kandinsky5.model
|
import comfy.ldm.kandinsky5.model
|
||||||
import comfy.ldm.anima.model
|
import comfy.ldm.anima.model
|
||||||
import comfy.ldm.ace.ace_step15
|
import comfy.ldm.ace.ace_step15
|
||||||
@ -2018,6 +2019,21 @@ class QwenImage(BaseModel):
|
|||||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class Ideogram4(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
if attention_mask is not None:
|
||||||
|
if torch.numel(attention_mask) != attention_mask.sum():
|
||||||
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
return out
|
||||||
|
|
||||||
class HunyuanImage21(BaseModel):
|
class HunyuanImage21(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||||
|
|||||||
@ -815,6 +815,13 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["default_ref_method"] = "negative_index"
|
dit_config["default_ref_method"] = "negative_index"
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "ideogram4"
|
||||||
|
dit_config["in_channels"] = state_dict['{}input_proj.weight'.format(key_prefix)].shape[1]
|
||||||
|
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||||
|
|||||||
@ -651,8 +651,7 @@ def ensure_pin_budget(size, evict_active=False):
|
|||||||
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
||||||
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
||||||
|
|
||||||
def ensure_pin_registerable(size, evict_active=True):
|
def free_registrations(shortfall, evict_active=True):
|
||||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
|
||||||
if MAX_PINNED_MEMORY <= 0:
|
if MAX_PINNED_MEMORY <= 0:
|
||||||
return False
|
return False
|
||||||
if shortfall <= 0:
|
if shortfall <= 0:
|
||||||
@ -674,6 +673,9 @@ def ensure_pin_registerable(size, evict_active=True):
|
|||||||
return True
|
return True
|
||||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||||
|
|
||||||
|
def ensure_pin_registerable(size, evict_active=True):
|
||||||
|
return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active)
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
def __init__(self, model: ModelPatcher):
|
def __init__(self, model: ModelPatcher):
|
||||||
self._set_model(model)
|
self._set_model(model)
|
||||||
|
|||||||
@ -89,13 +89,26 @@ def pin_memory(module, subset="weights", size=None):
|
|||||||
not comfy.model_management.ensure_pin_registerable(registerable_size)):
|
not comfy.model_management.ensure_pin_registerable(registerable_size)):
|
||||||
return _steal_pin(module, stack, buckets, size, priority)
|
return _steal_pin(module, stack, buckets, size, priority)
|
||||||
|
|
||||||
|
extended = False
|
||||||
try:
|
try:
|
||||||
hostbuf.extend(size=size)
|
hostbuf.extend(size=size, register=False)
|
||||||
|
extended = True
|
||||||
|
pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||||
|
pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||||
|
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||||
|
comfy.model_management.discard_cuda_async_error()
|
||||||
|
comfy.model_management.free_registrations(size)
|
||||||
|
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||||
|
comfy.model_management.discard_cuda_async_error()
|
||||||
|
del pin
|
||||||
|
hostbuf.truncate(offset, do_unregister=False)
|
||||||
|
return _steal_pin(module, stack, buckets, size, priority)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
if extended:
|
||||||
|
hostbuf.truncate(offset, do_unregister=False)
|
||||||
return _steal_pin(module, stack, buckets, size, priority)
|
return _steal_pin(module, stack, buckets, size, priority)
|
||||||
|
|
||||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
module._pin = pin
|
||||||
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
|
|
||||||
stack.append((module, offset))
|
stack.append((module, offset))
|
||||||
module._pin_registered = True
|
module._pin_registered = True
|
||||||
module._pin_stack_index = len(stack) - 1
|
module._pin_stack_index = len(stack) - 1
|
||||||
|
|||||||
10
comfy/sd.py
10
comfy/sd.py
@ -58,6 +58,7 @@ import comfy.text_encoders.omnigen2
|
|||||||
import comfy.text_encoders.qwen_image
|
import comfy.text_encoders.qwen_image
|
||||||
import comfy.text_encoders.hunyuan_image
|
import comfy.text_encoders.hunyuan_image
|
||||||
import comfy.text_encoders.z_image
|
import comfy.text_encoders.z_image
|
||||||
|
import comfy.text_encoders.ideogram4
|
||||||
import comfy.text_encoders.ovis
|
import comfy.text_encoders.ovis
|
||||||
import comfy.text_encoders.kandinsky5
|
import comfy.text_encoders.kandinsky5
|
||||||
import comfy.text_encoders.jina_clip_2
|
import comfy.text_encoders.jina_clip_2
|
||||||
@ -1298,6 +1299,7 @@ class CLIPType(Enum):
|
|||||||
COGVIDEOX = 27
|
COGVIDEOX = 27
|
||||||
LENS = 28
|
LENS = 28
|
||||||
PIXELDIT = 29
|
PIXELDIT = 29
|
||||||
|
IDEOGRAM4 = 30
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1596,8 +1598,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||||
elif te_model == TEModel.QWEN3_8B:
|
elif te_model == TEModel.QWEN3_8B:
|
||||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
|
if clip_type == CLIPType.IDEOGRAM4:
|
||||||
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
|
clip_target.clip = comfy.text_encoders.ideogram4.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Tokenizer
|
||||||
|
else:
|
||||||
|
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
|
||||||
elif te_model == TEModel.JINA_CLIP_2:
|
elif te_model == TEModel.JINA_CLIP_2:
|
||||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import comfy.text_encoders.qwen_image
|
|||||||
import comfy.text_encoders.hunyuan_image
|
import comfy.text_encoders.hunyuan_image
|
||||||
import comfy.text_encoders.kandinsky5
|
import comfy.text_encoders.kandinsky5
|
||||||
import comfy.text_encoders.z_image
|
import comfy.text_encoders.z_image
|
||||||
|
import comfy.text_encoders.ideogram4
|
||||||
import comfy.text_encoders.anima
|
import comfy.text_encoders.anima
|
||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
import comfy.text_encoders.longcat_image
|
import comfy.text_encoders.longcat_image
|
||||||
@ -1746,6 +1747,44 @@ class Omnigen2(supported_models_base.BASE):
|
|||||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
class Ideogram4(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "ideogram4",
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"multiplier": 1.0,
|
||||||
|
"shift": 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_usage_factor = 11.6
|
||||||
|
|
||||||
|
unet_extra_config = {
|
||||||
|
"num_attention_heads": 18,
|
||||||
|
"attention_head_dim": 256,
|
||||||
|
"intermediate_size": 12288,
|
||||||
|
"adaln_dim": 512,
|
||||||
|
"llm_features_dim": 53248,
|
||||||
|
"rope_theta": 5000000,
|
||||||
|
"mrope_section": [24, 20, 20],
|
||||||
|
"norm_eps": 1e-5,
|
||||||
|
}
|
||||||
|
latent_format = latent_formats.Flux2
|
||||||
|
|
||||||
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.Ideogram4(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref))
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect))
|
||||||
|
|
||||||
class QwenImage(supported_models_base.BASE):
|
class QwenImage(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "qwen_image",
|
"image_model": "qwen_image",
|
||||||
@ -2233,6 +2272,7 @@ models = [
|
|||||||
ACEStep15,
|
ACEStep15,
|
||||||
Omnigen2,
|
Omnigen2,
|
||||||
QwenImage,
|
QwenImage,
|
||||||
|
Ideogram4,
|
||||||
Flux2,
|
Flux2,
|
||||||
Lens,
|
Lens,
|
||||||
Kandinsky5Image,
|
Kandinsky5Image,
|
||||||
|
|||||||
77
comfy/text_encoders/ideogram4.py
Normal file
77
comfy/text_encoders/ideogram4.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
"""Ideogram 4 text encoder: Qwen3-VL-8B language model, 13-layer tap.
|
||||||
|
|
||||||
|
Ideogram 4 conditions on the concatenation of hidden states from 13 layers of
|
||||||
|
Qwen3-VL (layers 0,3,...,33,35), giving a 4096*13 = 53248-dim feature per token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from transformers import Qwen2Tokenizer
|
||||||
|
|
||||||
|
import comfy.text_encoders.llama
|
||||||
|
from comfy import sd1_clip
|
||||||
|
|
||||||
|
# Reference taps outputs of layers (0,3,...,35); comfy captures layer inputs, offset by +1.
|
||||||
|
IDEOGRAM4_TAP_LAYERS = [1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 36]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3VLTokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||||
|
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory,
|
||||||
|
embedding_size=4096, embedding_key='qwen3vl_8b', tokenizer_class=Qwen2Tokenizer,
|
||||||
|
has_start_token=False, has_end_token=False, pad_to_max_length=False,
|
||||||
|
max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
|
||||||
|
class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||||
|
name="qwen3vl_8b", tokenizer=Qwen3VLTokenizer)
|
||||||
|
|
||||||
|
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||||
|
if llama_template is None:
|
||||||
|
llama_text = self.llama_template.format(text)
|
||||||
|
else:
|
||||||
|
llama_text = llama_template.format(text)
|
||||||
|
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Qwen3-VL-8B = 5e6 (vs plain Qwen3-8B's 1e6)
|
||||||
|
# final_norm/lm_head off -> Ideogram only reads raw tapped hidden states
|
||||||
|
QWEN3VL_8B_CONFIG = {"rope_theta": 5000000.0, "final_norm": False, "lm_head": False}
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3VL8BModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="hidden", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||||
|
super().__init__(device=device, layer=IDEOGRAM4_TAP_LAYERS, layer_idx=None,
|
||||||
|
textmodel_json_config=dict(QWEN3VL_8B_CONFIG),
|
||||||
|
dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False,
|
||||||
|
model_class=comfy.text_encoders.llama.Qwen3_8B,
|
||||||
|
enable_attention_masks=attention_mask, return_attention_masks=attention_mask,
|
||||||
|
model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class Ideogram4TEModel(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Qwen3VL8BModel, model_options=model_options)
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||||
|
b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096) stacked in ascending layer order.
|
||||||
|
out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13). permute -> (B, seq, H, taps).
|
||||||
|
return out, pooled, extra
|
||||||
|
|
||||||
|
|
||||||
|
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
class Ideogram4TEModel_(Ideogram4TEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return Ideogram4TEModel_
|
||||||
@ -755,6 +755,18 @@ class File3DKSPLAT(ComfyTypeIO):
|
|||||||
Type = File3D
|
Type = File3D
|
||||||
|
|
||||||
|
|
||||||
|
@comfytype(io_type="FILE_3D_SPLAT_ANY")
|
||||||
|
class File3DSplatAny(ComfyTypeIO):
|
||||||
|
"""General 3D Gaussian splat file type - accepts any supported splat container (.ply / .spz / .splat / .ksplat)."""
|
||||||
|
Type = File3D
|
||||||
|
|
||||||
|
|
||||||
|
@comfytype(io_type="FILE_3D_POINT_CLOUD_ANY")
|
||||||
|
class File3DPointCloudAny(ComfyTypeIO):
|
||||||
|
"""General point cloud file type - accepts any supported point cloud container (currently .ply)."""
|
||||||
|
Type = File3D
|
||||||
|
|
||||||
|
|
||||||
@comfytype(io_type="HOOKS")
|
@comfytype(io_type="HOOKS")
|
||||||
class Hooks(ComfyTypeIO):
|
class Hooks(ComfyTypeIO):
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -2336,6 +2348,8 @@ __all__ = [
|
|||||||
"File3DSPLAT",
|
"File3DSPLAT",
|
||||||
"File3DSPZ",
|
"File3DSPZ",
|
||||||
"File3DKSPLAT",
|
"File3DKSPLAT",
|
||||||
|
"File3DSplatAny",
|
||||||
|
"File3DPointCloudAny",
|
||||||
"Hooks",
|
"Hooks",
|
||||||
"HookKeyframes",
|
"HookKeyframes",
|
||||||
"TimestepsRange",
|
"TimestepsRange",
|
||||||
|
|||||||
@ -285,7 +285,7 @@ class AudioSaveHelper:
|
|||||||
results = []
|
results = []
|
||||||
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
||||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
file = f"{filename_with_batch_num}_{counter:05}.{format}"
|
||||||
output_path = os.path.join(full_output_folder, file)
|
output_path = os.path.join(full_output_folder, file)
|
||||||
|
|
||||||
# Use original sample rate initially
|
# Use original sample rate initially
|
||||||
|
|||||||
@ -43,6 +43,7 @@ class BFLFluxEraseRequest(BaseModel):
|
|||||||
"white (255) marks areas to remove, black (0) marks areas to preserve.",
|
"white (255) marks areas to remove, black (0) marks areas to preserve.",
|
||||||
)
|
)
|
||||||
dilate_pixels: int = Field(10)
|
dilate_pixels: int = Field(10)
|
||||||
|
seed: int | None = Field(None)
|
||||||
output_format: str = Field("png")
|
output_format: str = Field("png")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -97,3 +97,28 @@ class BriaRemoveVideoBackgroundResult(BaseModel):
|
|||||||
class BriaRemoveVideoBackgroundResponse(BaseModel):
|
class BriaRemoveVideoBackgroundResponse(BaseModel):
|
||||||
status: str = Field(...)
|
status: str = Field(...)
|
||||||
result: BriaRemoveVideoBackgroundResult | None = Field(None)
|
result: BriaRemoveVideoBackgroundResult | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class BriaVideoGreenScreenRequest(BaseModel):
|
||||||
|
video: str = Field(..., description="Publicly accessible URL of the input video.")
|
||||||
|
green_shade: str = Field(
|
||||||
|
default="broadcast_green",
|
||||||
|
description="Solid chroma-key shade applied behind the foreground "
|
||||||
|
"(broadcast_green, chroma_green, or blue_screen).",
|
||||||
|
)
|
||||||
|
output_container_and_codec: str = Field(...)
|
||||||
|
preserve_audio: bool = Field(True)
|
||||||
|
seed: int = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class BriaVideoReplaceBackgroundRequest(BaseModel):
|
||||||
|
video: str = Field(..., description="Publicly accessible URL of the input (foreground) video.")
|
||||||
|
background_url: str = Field(
|
||||||
|
...,
|
||||||
|
description="Publicly accessible URL of the background image or video to composite behind "
|
||||||
|
"the foreground. Stretched to the foreground frame; match its aspect ratio for "
|
||||||
|
"undistorted results.",
|
||||||
|
)
|
||||||
|
output_container_and_codec: str = Field(...)
|
||||||
|
preserve_audio: bool = Field(True)
|
||||||
|
seed: int = Field(...)
|
||||||
|
|||||||
@ -108,13 +108,19 @@ class GeminiVideoMetadata(BaseModel):
|
|||||||
startOffset: GeminiOffset | None = Field(None)
|
startOffset: GeminiOffset | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiThinkingConfig(BaseModel):
|
||||||
|
includeThoughts: bool | None = Field(None)
|
||||||
|
thinkingLevel: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class GeminiGenerationConfig(BaseModel):
|
class GeminiGenerationConfig(BaseModel):
|
||||||
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
maxOutputTokens: int | None = Field(None, ge=16, le=65536)
|
||||||
seed: int | None = Field(None)
|
seed: int | None = Field(None)
|
||||||
stopSequences: list[str] | None = Field(None)
|
stopSequences: list[str] | None = Field(None)
|
||||||
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
||||||
topK: int | None = Field(None, ge=1)
|
topK: int | None = Field(None, ge=1)
|
||||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||||
|
thinkingConfig: GeminiThinkingConfig | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageOutputOptions(BaseModel):
|
class GeminiImageOutputOptions(BaseModel):
|
||||||
@ -128,11 +134,6 @@ class GeminiImageConfig(BaseModel):
|
|||||||
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
|
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
|
||||||
|
|
||||||
|
|
||||||
class GeminiThinkingConfig(BaseModel):
|
|
||||||
includeThoughts: bool | None = Field(None)
|
|
||||||
thinkingLevel: str = Field(...)
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||||
responseModalities: list[str] | None = Field(None)
|
responseModalities: list[str] | None = Field(None)
|
||||||
imageConfig: GeminiImageConfig | None = Field(None)
|
imageConfig: GeminiImageConfig | None = Field(None)
|
||||||
|
|||||||
@ -290,3 +290,19 @@ class IdeogramV3Request(BaseModel):
|
|||||||
None,
|
None,
|
||||||
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
|
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IdeogramV4Request(BaseModel):
|
||||||
|
text_prompt: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Natural-language prompt; Magic Prompt is applied automatically. "
|
||||||
|
"Supply exactly one of text_prompt or json_prompt.",
|
||||||
|
)
|
||||||
|
json_prompt: dict[str, Any] | None = Field(
|
||||||
|
None,
|
||||||
|
description="Structured V4 prompt object consumed directly (disables Magic Prompt). "
|
||||||
|
"Supply exactly one of text_prompt or json_prompt.",
|
||||||
|
)
|
||||||
|
resolution: str | None = Field(None, description="Output resolution in WIDTHxHEIGHT (e.g. '2048x2048').")
|
||||||
|
rendering_speed: str | None = Field(None, description="Rendering speed: 'TURBO', 'DEFAULT', or 'QUALITY'.")
|
||||||
|
enable_copyright_detection: bool | None = Field(None, description="Opt into post-generation copyright detection.")
|
||||||
|
|||||||
@ -534,6 +534,15 @@ class FluxEraseNode(IO.ComfyNode):
|
|||||||
max=25,
|
max=25,
|
||||||
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
|
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
|
||||||
),
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="The random seed used for creating the noise.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[IO.Image.Output()],
|
outputs=[IO.Image.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
@ -553,6 +562,7 @@ class FluxEraseNode(IO.ComfyNode):
|
|||||||
image: Input.Image,
|
image: Input.Image,
|
||||||
mask: Input.Image,
|
mask: Input.Image,
|
||||||
dilate_pixels: int = 10,
|
dilate_pixels: int = 10,
|
||||||
|
seed: int = 0,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_image_dimensions(image, min_width=256, min_height=256)
|
validate_image_dimensions(image, min_width=256, min_height=256)
|
||||||
mask = resize_mask_to_image(mask, image)
|
mask = resize_mask_to_image(mask, image)
|
||||||
@ -565,6 +575,7 @@ class FluxEraseNode(IO.ComfyNode):
|
|||||||
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
|
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
|
||||||
mask=mask,
|
mask=mask,
|
||||||
dilate_pixels=dilate_pixels,
|
dilate_pixels=dilate_pixels,
|
||||||
|
seed=seed,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,14 +1,19 @@
|
|||||||
|
import av
|
||||||
|
import torch
|
||||||
|
from av.codec import CodecContext
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.latest import IO, ComfyExtension, Input
|
from comfy_api.latest import IO, ComfyExtension, Input
|
||||||
from comfy_api_nodes.apis.bria import (
|
from comfy_api_nodes.apis.bria import (
|
||||||
BriaEditImageRequest,
|
BriaEditImageRequest,
|
||||||
|
BriaImageEditResponse,
|
||||||
BriaRemoveBackgroundRequest,
|
BriaRemoveBackgroundRequest,
|
||||||
BriaRemoveBackgroundResponse,
|
BriaRemoveBackgroundResponse,
|
||||||
BriaRemoveVideoBackgroundRequest,
|
BriaRemoveVideoBackgroundRequest,
|
||||||
BriaRemoveVideoBackgroundResponse,
|
BriaRemoveVideoBackgroundResponse,
|
||||||
BriaImageEditResponse,
|
|
||||||
BriaStatusResponse,
|
BriaStatusResponse,
|
||||||
|
BriaVideoGreenScreenRequest,
|
||||||
|
BriaVideoReplaceBackgroundRequest,
|
||||||
InputModerationSettings,
|
InputModerationSettings,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
@ -316,6 +321,248 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
class BriaVideoGreenScreen(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="BriaVideoGreenScreen",
|
||||||
|
display_name="Bria Video Green Screen",
|
||||||
|
category="partner/video/Bria",
|
||||||
|
description="Replace a video's background with a solid chroma-key screen using Bria.",
|
||||||
|
inputs=[
|
||||||
|
IO.Video.Input("video"),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"green_shade",
|
||||||
|
options=["broadcast_green", "chroma_green", "blue_screen"],
|
||||||
|
tooltip="Solid chroma-key shade applied behind the foreground: "
|
||||||
|
"broadcast_green (#00B140), chroma_green (#00FF00), or blue_screen (#0000FF).",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed controls whether the node should re-run; "
|
||||||
|
"results are non-deterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[IO.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
video: Input.Video,
|
||||||
|
green_shade: str,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_video_duration(video, max_duration=60.0)
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/bria/v2/video/edit/green_screen", method="POST"),
|
||||||
|
data=BriaVideoGreenScreenRequest(
|
||||||
|
video=await upload_video_to_comfyapi(cls, video),
|
||||||
|
green_shade=green_shade,
|
||||||
|
output_container_and_codec="mp4_h264",
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
response_model=BriaStatusResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
response_model=BriaRemoveVideoBackgroundResponse,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
class BriaVideoReplaceBackground(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="BriaVideoReplaceBackground",
|
||||||
|
display_name="Bria Video Replace Background",
|
||||||
|
category="partner/video/Bria",
|
||||||
|
description="Replace a video's background with a supplied image or video using Bria. "
|
||||||
|
"The output keeps the foreground's resolution and frame rate; a background with a "
|
||||||
|
"different aspect ratio is stretched to fit, so match it for undistorted results.",
|
||||||
|
inputs=[
|
||||||
|
IO.Video.Input("video", tooltip="Foreground video whose background is replaced."),
|
||||||
|
IO.Image.Input(
|
||||||
|
"background_image",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Background image to composite behind the foreground. "
|
||||||
|
"Provide either a background image or a background video, not both.",
|
||||||
|
),
|
||||||
|
IO.Video.Input(
|
||||||
|
"background_video",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Background video to composite behind the foreground. "
|
||||||
|
"Provide either a background image or a background video, not both.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed controls whether the node should re-run; "
|
||||||
|
"results are non-deterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[IO.Video.Output()],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
video: Input.Video,
|
||||||
|
seed: int,
|
||||||
|
background_image: Input.Image | None = None,
|
||||||
|
background_video: Input.Video | None = None,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
if (background_image is None) == (background_video is None):
|
||||||
|
raise ValueError("Provide either a background image or a background video, not both.")
|
||||||
|
validate_video_duration(video, max_duration=60.0)
|
||||||
|
if background_video is not None:
|
||||||
|
validate_video_duration(background_video, max_duration=60.0)
|
||||||
|
background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
|
||||||
|
else:
|
||||||
|
background_url = await upload_image_to_comfyapi(cls, background_image, wait_label="Uploading background")
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
|
||||||
|
data=BriaVideoReplaceBackgroundRequest(
|
||||||
|
video=await upload_video_to_comfyapi(cls, video),
|
||||||
|
background_url=background_url,
|
||||||
|
output_container_and_codec="mp4_h264",
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
response_model=BriaStatusResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
response_model=BriaRemoveVideoBackgroundResponse,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||||
|
|
||||||
|
|
||||||
|
def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]:
|
||||||
|
"""Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask.
|
||||||
|
|
||||||
|
VP9 keeps its alpha in a side layer that PyAV's default vp9 decoder drops, so the frames
|
||||||
|
are decoded with libvpx-vp9. Returns RGB images [B,H,W,3] in 0..1 and a mask [B,H,W]
|
||||||
|
following the Load Image convention (1 = transparent) for compositing or Save WEBM.
|
||||||
|
"""
|
||||||
|
rgb_frames: list[torch.Tensor] = []
|
||||||
|
alpha_frames: list[torch.Tensor] = []
|
||||||
|
with av.open(video.get_stream_source(), mode="r") as container:
|
||||||
|
stream = container.streams.video[0]
|
||||||
|
decoder = CodecContext.create("libvpx-vp9", "r") if stream.codec_context.name == "vp9" else None
|
||||||
|
for packet in container.demux(stream):
|
||||||
|
for frame in (decoder.decode(packet) if decoder is not None else packet.decode()):
|
||||||
|
rgba = torch.from_numpy(frame.to_ndarray(format="rgba")).float() / 255.0
|
||||||
|
rgb_frames.append(rgba[..., :3])
|
||||||
|
alpha_frames.append(rgba[..., 3])
|
||||||
|
images = torch.stack(rgb_frames) if rgb_frames else torch.zeros(0, 0, 0, 3)
|
||||||
|
mask = (1.0 - torch.stack(alpha_frames)) if alpha_frames else torch.zeros((images.shape[0], 64, 64))
|
||||||
|
return images, mask
|
||||||
|
|
||||||
|
|
||||||
|
class BriaTransparentVideoBackground(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="BriaTransparentVideoBackground",
|
||||||
|
display_name="Bria Remove Video Background (Transparent)",
|
||||||
|
category="partner/video/Bria",
|
||||||
|
description="Remove the background from a video using Bria and return the cut-out frames "
|
||||||
|
"plus an alpha mask. Connect both to a compositing node, or feed them to Save WEBM to "
|
||||||
|
"write a transparent video.",
|
||||||
|
inputs=[
|
||||||
|
IO.Video.Input("video"),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed controls whether the node should re-run; "
|
||||||
|
"results are non-deterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Image.Output(display_name="images"),
|
||||||
|
IO.Mask.Output(display_name="mask"),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
video: Input.Video,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_video_duration(video, max_duration=60.0)
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
|
||||||
|
data=BriaRemoveVideoBackgroundRequest(
|
||||||
|
video=await upload_video_to_comfyapi(cls, video),
|
||||||
|
background_color="Transparent",
|
||||||
|
output_container_and_codec="webm_vp9",
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
response_model=BriaStatusResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status,
|
||||||
|
response_model=BriaRemoveVideoBackgroundResponse,
|
||||||
|
)
|
||||||
|
video_out = await download_url_to_video_output(response.result.video_url)
|
||||||
|
images, mask = _video_to_images_and_mask(video_out)
|
||||||
|
return IO.NodeOutput(images, mask)
|
||||||
|
|
||||||
|
|
||||||
class BriaExtension(ComfyExtension):
|
class BriaExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -323,6 +570,9 @@ class BriaExtension(ComfyExtension):
|
|||||||
BriaImageEditNode,
|
BriaImageEditNode,
|
||||||
BriaRemoveImageBackground,
|
BriaRemoveImageBackground,
|
||||||
BriaRemoveVideoBackground,
|
BriaRemoveVideoBackground,
|
||||||
|
BriaVideoGreenScreen,
|
||||||
|
# BriaVideoReplaceBackground, # server returns Status 500 when we pass background video
|
||||||
|
BriaTransparentVideoBackground,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from io import BytesIO
|
|||||||
import torch
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy.utils import common_upscale
|
||||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||||
from comfy_api_nodes.apis.bytedance import (
|
from comfy_api_nodes.apis.bytedance import (
|
||||||
RECOMMENDED_PRESETS,
|
RECOMMENDED_PRESETS,
|
||||||
@ -131,6 +132,44 @@ def _prepare_seedance_image(image: Input.Image) -> Input.Image:
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
# Supported output aspect ratios, used to pre-size FLF frames to matching pixel pair to avoid the 1080p stretch jump.
|
||||||
|
SEEDANCE2_RATIO_WH = {
|
||||||
|
"16:9": (16, 9),
|
||||||
|
"4:3": (4, 3),
|
||||||
|
"1:1": (1, 1),
|
||||||
|
"3:4": (3, 4),
|
||||||
|
"9:16": (9, 16),
|
||||||
|
"21:9": (21, 9),
|
||||||
|
}
|
||||||
|
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080}
|
||||||
|
|
||||||
|
|
||||||
|
def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
|
||||||
|
"""Exact supported output (width, height) for (resolution, ratio).
|
||||||
|
|
||||||
|
The shorter side equals the resolution number (e.g. 1080p 16:9 -> 1920x1080). For ratio
|
||||||
|
"adaptive" (or any unexpected value) the ratio is derived from the image's own aspect, snapped
|
||||||
|
to the nearest supported ratio, so the output keeps the frame's orientation.
|
||||||
|
"""
|
||||||
|
short = SEEDANCE2_RES_SHORT_SIDE[resolution]
|
||||||
|
if ratio not in SEEDANCE2_RATIO_WH:
|
||||||
|
aspect = image.shape[-2] / image.shape[-3] # W / H; tensor is (B, H, W, C)
|
||||||
|
ratio = min(SEEDANCE2_RATIO_WH, key=lambda k: abs(SEEDANCE2_RATIO_WH[k][0] / SEEDANCE2_RATIO_WH[k][1] - aspect))
|
||||||
|
rw, rh = SEEDANCE2_RATIO_WH[ratio]
|
||||||
|
if rw >= rh: # landscape or square: shorter side is the height
|
||||||
|
out_w, out_h = round(short * rw / rh), short
|
||||||
|
else: # portrait: shorter side is the width
|
||||||
|
out_w, out_h = short, round(short * rh / rw)
|
||||||
|
return out_w - out_w % 2, out_h - out_h % 2
|
||||||
|
|
||||||
|
|
||||||
|
def _resize_to_exact(image: torch.Tensor, width: int, height: int) -> torch.Tensor:
|
||||||
|
"""Center-crop to the target aspect and resize to exactly width x height (lanczos)."""
|
||||||
|
samples = image.movedim(-1, 1) # (B, H, W, C) -> (B, C, H, W)
|
||||||
|
resized = common_upscale(samples, width, height, "lanczos", "center")
|
||||||
|
return resized.movedim(1, -1)
|
||||||
|
|
||||||
|
|
||||||
async def _resolve_reference_assets(
|
async def _resolve_reference_assets(
|
||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode],
|
||||||
asset_ids: list[str],
|
asset_ids: list[str],
|
||||||
@ -1790,10 +1829,28 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
|||||||
if last_frame is not None and last_frame_asset_id:
|
if last_frame is not None and last_frame_asset_id:
|
||||||
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
|
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
|
||||||
|
|
||||||
if first_frame is not None:
|
request_ratio = model["ratio"]
|
||||||
first_frame = _prepare_seedance_image(first_frame)
|
if first_frame_asset_id or last_frame_asset_id:
|
||||||
if last_frame is not None:
|
if first_frame is not None:
|
||||||
last_frame = _prepare_seedance_image(last_frame)
|
first_frame = _prepare_seedance_image(first_frame)
|
||||||
|
if last_frame is not None:
|
||||||
|
last_frame = _prepare_seedance_image(last_frame)
|
||||||
|
else:
|
||||||
|
# The 1080p FLF stretch fix (pre-size frames to a supported pixel pair + submit ratio="adaptive")
|
||||||
|
# only applies to local image inputs we can resize.
|
||||||
|
request_ratio = "adaptive"
|
||||||
|
target_dims: tuple[int, int] | None = None
|
||||||
|
if first_frame is not None:
|
||||||
|
validate_image_aspect_ratio(first_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||||
|
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||||
|
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], first_frame)
|
||||||
|
first_frame = _resize_to_exact(first_frame, *target_dims)
|
||||||
|
if last_frame is not None:
|
||||||
|
validate_image_aspect_ratio(last_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||||
|
validate_image_dimensions(last_frame, min_width=300, min_height=300)
|
||||||
|
if target_dims is None:
|
||||||
|
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], last_frame)
|
||||||
|
last_frame = _resize_to_exact(last_frame, *target_dims)
|
||||||
|
|
||||||
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
|
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
|
||||||
image_assets: dict[str, str] = {}
|
image_assets: dict[str, str] = {}
|
||||||
@ -1844,7 +1901,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
|||||||
content=content,
|
content=content,
|
||||||
generate_audio=model["generate_audio"],
|
generate_audio=model["generate_audio"],
|
||||||
resolution=model["resolution"],
|
resolution=model["resolution"],
|
||||||
ratio=model["ratio"],
|
ratio=request_ratio,
|
||||||
duration=model["duration"],
|
duration=model["duration"],
|
||||||
seed=seed,
|
seed=seed,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import os
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from fnmatch import fnmatch
|
from fnmatch import fnmatch
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
@ -19,6 +19,7 @@ from comfy_api_nodes.apis.gemini import (
|
|||||||
GeminiContent,
|
GeminiContent,
|
||||||
GeminiFileData,
|
GeminiFileData,
|
||||||
GeminiGenerateContentRequest,
|
GeminiGenerateContentRequest,
|
||||||
|
GeminiGenerationConfig,
|
||||||
GeminiGenerateContentResponse,
|
GeminiGenerateContentResponse,
|
||||||
GeminiImageConfig,
|
GeminiImageConfig,
|
||||||
GeminiImageGenerateContentRequest,
|
GeminiImageGenerateContentRequest,
|
||||||
@ -40,13 +41,18 @@ from comfy_api_nodes.util import (
|
|||||||
get_number_of_images,
|
get_number_of_images,
|
||||||
sync_op,
|
sync_op,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
|
upload_audio_to_comfyapi,
|
||||||
|
upload_image_to_comfyapi,
|
||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
|
upload_video_to_comfyapi,
|
||||||
validate_string,
|
validate_string,
|
||||||
video_to_base64_string,
|
video_to_base64_string,
|
||||||
)
|
)
|
||||||
|
|
||||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||||
|
GEMINI_URL_INPUT_BUDGET = 10
|
||||||
|
GEMINI_MAX_INLINE_BYTES = 18 * 1024 * 1024
|
||||||
GEMINI_IMAGE_SYS_PROMPT = (
|
GEMINI_IMAGE_SYS_PROMPT = (
|
||||||
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
|
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
|
||||||
"Interpret all user input—regardless of "
|
"Interpret all user input—regardless of "
|
||||||
@ -285,6 +291,140 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
|||||||
return final_price / 1_000_000.0
|
return final_price / 1_000_000.0
|
||||||
|
|
||||||
|
|
||||||
|
def create_video_parts(video_input: Input.Video) -> list[GeminiPart]:
|
||||||
|
"""Convert a single video input to Gemini API compatible parts (inline MP4/H.264)."""
|
||||||
|
base_64_string = video_to_base64_string(
|
||||||
|
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
GeminiPart(
|
||||||
|
inlineData=GeminiInlineData(
|
||||||
|
mimeType=GeminiMimeType.video_mp4,
|
||||||
|
data=base_64_string,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_audio_parts(audio_input: Input.Audio) -> list[GeminiPart]:
|
||||||
|
"""Convert an audio input to Gemini API compatible parts (one inline MP3 part per batch item)."""
|
||||||
|
audio_parts: list[GeminiPart] = []
|
||||||
|
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||||
|
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||||
|
audio_at_index = Input.Audio(
|
||||||
|
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
||||||
|
sample_rate=audio_input["sample_rate"],
|
||||||
|
)
|
||||||
|
# Convert to MP3 format for compatibility with Gemini API
|
||||||
|
audio_bytes = audio_to_base64_string(
|
||||||
|
audio_at_index,
|
||||||
|
container_format="mp3",
|
||||||
|
codec_name="libmp3lame",
|
||||||
|
)
|
||||||
|
audio_parts.append(
|
||||||
|
GeminiPart(
|
||||||
|
inlineData=GeminiInlineData(
|
||||||
|
mimeType=GeminiMimeType.audio_mp3,
|
||||||
|
data=audio_bytes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return audio_parts
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_images(images: list[Input.Image]) -> list[torch.Tensor]:
|
||||||
|
"""Expand any batched image tensors into individual (H, W, C) frames, preserving order."""
|
||||||
|
frames: list[torch.Tensor] = []
|
||||||
|
for img in images:
|
||||||
|
if len(img.shape) == 4:
|
||||||
|
frames.extend(img[i] for i in range(img.shape[0]))
|
||||||
|
else:
|
||||||
|
frames.append(img)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_audio(audios: list[Input.Audio]) -> list[Input.Audio]:
|
||||||
|
"""Expand any batched audio inputs into individual single-clip audio inputs, preserving order."""
|
||||||
|
clips: list[Input.Audio] = []
|
||||||
|
for audio in audios:
|
||||||
|
waveform = audio["waveform"]
|
||||||
|
for i in range(waveform.shape[0]):
|
||||||
|
clips.append(Input.Audio(waveform=waveform[i].unsqueeze(0), sample_rate=audio["sample_rate"]))
|
||||||
|
return clips
|
||||||
|
|
||||||
|
|
||||||
|
async def _media_url_part(cls: type[IO.ComfyNode], kind: str, payload: Any) -> GeminiPart:
|
||||||
|
"""Upload a single media unit to ComfyAPI storage and return a fileData (URL) part."""
|
||||||
|
if kind == "image":
|
||||||
|
url = await upload_image_to_comfyapi(cls, payload, mime_type="image/png", wait_label="Uploading image")
|
||||||
|
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.image_png, fileUri=url))
|
||||||
|
if kind == "audio":
|
||||||
|
url = await upload_audio_to_comfyapi(
|
||||||
|
cls, payload, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mp3"
|
||||||
|
)
|
||||||
|
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.audio_mp3, fileUri=url))
|
||||||
|
url = await upload_video_to_comfyapi(cls, payload, wait_label="Uploading video")
|
||||||
|
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.video_mp4, fileUri=url))
|
||||||
|
|
||||||
|
|
||||||
|
def _media_inline_part(kind: str, payload: Any) -> tuple[GeminiPart, int]:
|
||||||
|
"""Encode a single media unit as an inline base64 part; returns (part, base64_length)."""
|
||||||
|
if kind == "image":
|
||||||
|
data = tensor_to_base64_string(payload, mime_type="image/webp")
|
||||||
|
mime = GeminiMimeType.image_webp
|
||||||
|
elif kind == "audio":
|
||||||
|
data = audio_to_base64_string(payload, container_format="mp3", codec_name="libmp3lame")
|
||||||
|
mime = GeminiMimeType.audio_mp3
|
||||||
|
else:
|
||||||
|
data = video_to_base64_string(
|
||||||
|
payload, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||||
|
)
|
||||||
|
mime = GeminiMimeType.video_mp4
|
||||||
|
return GeminiPart(inlineData=GeminiInlineData(mimeType=mime, data=data)), len(data)
|
||||||
|
|
||||||
|
|
||||||
|
async def build_gemini_media_parts(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
images: list[Input.Image],
|
||||||
|
audios: list[Input.Audio],
|
||||||
|
videos: list[Input.Video],
|
||||||
|
*,
|
||||||
|
url_budget: int = GEMINI_URL_INPUT_BUDGET,
|
||||||
|
max_inline_bytes: int = GEMINI_MAX_INLINE_BYTES,
|
||||||
|
) -> list[GeminiPart]:
|
||||||
|
"""Build Gemini parts for multimodal inputs (images, audio, video).
|
||||||
|
|
||||||
|
fileData URLs are preferred for every media type: the upload is fetched directly by the
|
||||||
|
model, keeping the request body tiny regardless of media size. The URL budget is shared
|
||||||
|
across all media and assigned largest-first (video, then audio, then images), so that if it
|
||||||
|
is ever exhausted the inline-base64 overflow is limited to the smallest items. Total inline
|
||||||
|
payload is capped by `max_inline_bytes`.
|
||||||
|
"""
|
||||||
|
units: list[tuple[str, Any]] = (
|
||||||
|
[("video", v) for v in videos]
|
||||||
|
+ [("audio", a) for a in _flatten_audio(audios)]
|
||||||
|
+ [("image", f) for f in _flatten_images(images)]
|
||||||
|
)
|
||||||
|
|
||||||
|
parts: list[GeminiPart] = []
|
||||||
|
url_used = 0
|
||||||
|
inline_bytes = 0
|
||||||
|
for kind, payload in units:
|
||||||
|
if url_used < url_budget:
|
||||||
|
parts.append(await _media_url_part(cls, kind, payload))
|
||||||
|
url_used += 1
|
||||||
|
continue
|
||||||
|
part, nbytes = _media_inline_part(kind, payload)
|
||||||
|
inline_bytes += nbytes
|
||||||
|
if inline_bytes > max_inline_bytes:
|
||||||
|
raise ValueError(
|
||||||
|
f"Too much media to send inline (over {max_inline_bytes // (1024 * 1024)}MB after the first "
|
||||||
|
f"{url_budget} inputs are uploaded as URLs). Reduce the number or size of attached media."
|
||||||
|
)
|
||||||
|
parts.append(part)
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
class GeminiNode(IO.ComfyNode):
|
class GeminiNode(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Node to generate text responses from a Gemini model.
|
Node to generate text responses from a Gemini model.
|
||||||
@ -407,58 +547,9 @@ class GeminiNode(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
|
||||||
"""Convert video input to Gemini API compatible parts."""
|
|
||||||
|
|
||||||
base_64_string = video_to_base64_string(
|
|
||||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
GeminiPart(
|
|
||||||
inlineData=GeminiInlineData(
|
|
||||||
mimeType=GeminiMimeType.video_mp4,
|
|
||||||
data=base_64_string,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
|
|
||||||
"""
|
|
||||||
Convert audio input to Gemini API compatible parts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of GeminiPart objects containing the encoded audio.
|
|
||||||
"""
|
|
||||||
audio_parts: list[GeminiPart] = []
|
|
||||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
|
||||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
|
||||||
audio_at_index = Input.Audio(
|
|
||||||
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
|
||||||
sample_rate=audio_input["sample_rate"],
|
|
||||||
)
|
|
||||||
# Convert to MP3 format for compatibility with Gemini API
|
|
||||||
audio_bytes = audio_to_base64_string(
|
|
||||||
audio_at_index,
|
|
||||||
container_format="mp3",
|
|
||||||
codec_name="libmp3lame",
|
|
||||||
)
|
|
||||||
audio_parts.append(
|
|
||||||
GeminiPart(
|
|
||||||
inlineData=GeminiInlineData(
|
|
||||||
mimeType=GeminiMimeType.audio_mp3,
|
|
||||||
data=audio_bytes,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return audio_parts
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
@ -482,9 +573,9 @@ class GeminiNode(IO.ComfyNode):
|
|||||||
if images is not None:
|
if images is not None:
|
||||||
parts.extend(await create_image_parts(cls, images))
|
parts.extend(await create_image_parts(cls, images))
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
parts.extend(cls.create_audio_parts(audio))
|
parts.extend(create_audio_parts(audio))
|
||||||
if video is not None:
|
if video is not None:
|
||||||
parts.extend(cls.create_video_parts(video))
|
parts.extend(create_video_parts(video))
|
||||||
if files is not None:
|
if files is not None:
|
||||||
parts.extend(files)
|
parts.extend(files)
|
||||||
|
|
||||||
@ -512,6 +603,210 @@ class GeminiNode(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||||
|
|
||||||
|
|
||||||
|
GEMINI_V2_MODELS: dict[str, str] = {
|
||||||
|
"Gemini 3.1 Pro": "gemini-3.1-pro-preview",
|
||||||
|
"Gemini 3.1 Flash-Lite": "gemini-3.1-flash-lite-preview",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _gemini_text_model_inputs(thinking_default: str) -> list[Input]:
|
||||||
|
"""Per-model inputs revealed by the model DynamicCombo (shared media + sampling controls)."""
|
||||||
|
return [
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"images",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Image.Input("image"),
|
||||||
|
names=[f"image_{i}" for i in range(1, 17)],
|
||||||
|
min=0,
|
||||||
|
),
|
||||||
|
tooltip="Optional image(s) to use as context for the model. Up to 16 images.",
|
||||||
|
),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"audio",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Audio.Input("audio"),
|
||||||
|
names=["audio_1"],
|
||||||
|
min=0,
|
||||||
|
),
|
||||||
|
tooltip="Optional audio clip to use as context for the model.",
|
||||||
|
),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"video",
|
||||||
|
template=IO.Autogrow.TemplateNames(
|
||||||
|
IO.Video.Input("video"),
|
||||||
|
names=["video_1"],
|
||||||
|
min=0,
|
||||||
|
),
|
||||||
|
tooltip="Optional video clip to use as context for the model.",
|
||||||
|
),
|
||||||
|
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||||
|
"files",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Optional file(s) to use as context for the model. "
|
||||||
|
"Accepts inputs from the Gemini Input Files node.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"thinking_level",
|
||||||
|
options=["LOW", "HIGH"],
|
||||||
|
default=thinking_default,
|
||||||
|
tooltip="How hard the model reasons internally before answering. "
|
||||||
|
"HIGH improves quality on difficult tasks but costs more (thinking) tokens and is slower.",
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"temperature",
|
||||||
|
default=1.0,
|
||||||
|
min=0.0,
|
||||||
|
max=2.0,
|
||||||
|
step=0.01,
|
||||||
|
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more creative.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"top_p",
|
||||||
|
default=0.95,
|
||||||
|
min=0.0,
|
||||||
|
max=1.0,
|
||||||
|
step=0.01,
|
||||||
|
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"max_output_tokens",
|
||||||
|
default=32768,
|
||||||
|
min=16,
|
||||||
|
max=65536,
|
||||||
|
tooltip="Maximum tokens to generate, including the model's internal thinking. "
|
||||||
|
"With thinking_level HIGH, a low value can leave no room for the answer; raise this if "
|
||||||
|
"responses come back empty or truncated. The model stops early when finished, so a higher "
|
||||||
|
"cap costs nothing extra for short replies.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiNodeV2(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GeminiNodeV2",
|
||||||
|
display_name="Google Gemini",
|
||||||
|
category="partner/text/Gemini",
|
||||||
|
essentials_category="Text Generation",
|
||||||
|
description="Generate text responses with Google's Gemini models. Provide a text prompt and, "
|
||||||
|
"optionally, one or more images, audio clips, videos, or files as multimodal context.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text input to the model. Include detailed instructions, questions, or context.",
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option("Gemini 3.1 Pro", _gemini_text_model_inputs("HIGH")),
|
||||||
|
IO.DynamicCombo.Option("Gemini 3.1 Flash-Lite", _gemini_text_model_inputs("LOW")),
|
||||||
|
],
|
||||||
|
tooltip="The Gemini model used to generate the response.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=42,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed for sampling. Set to 0 for a random seed. Deterministic output isn't guaranteed.",
|
||||||
|
),
|
||||||
|
IO.String.Input(
|
||||||
|
"system_prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
optional=True,
|
||||||
|
advanced=True,
|
||||||
|
tooltip="Foundational instructions that dictate the model's behavior.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.String.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$m := widgets.model;
|
||||||
|
$contains($m, "lite") ? {
|
||||||
|
"type": "list_usd",
|
||||||
|
"usd": [0.00025, 0.0015],
|
||||||
|
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||||
|
} : {
|
||||||
|
"type": "list_usd",
|
||||||
|
"usd": [0.002, 0.012],
|
||||||
|
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||||
|
}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
system_prompt: str = "",
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
model_id = GEMINI_V2_MODELS[model["model"]]
|
||||||
|
|
||||||
|
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||||
|
images = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||||
|
audios = [a for a in (model.get("audio") or {}).values() if a is not None]
|
||||||
|
videos = [v for v in (model.get("video") or {}).values() if v is not None]
|
||||||
|
if images or audios or videos:
|
||||||
|
parts.extend(await build_gemini_media_parts(cls, images, audios, videos))
|
||||||
|
files = model.get("files")
|
||||||
|
if files is not None:
|
||||||
|
parts.extend(files)
|
||||||
|
|
||||||
|
gemini_system_prompt = None
|
||||||
|
if system_prompt:
|
||||||
|
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||||
|
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
|
||||||
|
data=GeminiGenerateContentRequest(
|
||||||
|
contents=[
|
||||||
|
GeminiContent(
|
||||||
|
role=GeminiRole.user,
|
||||||
|
parts=parts,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
generationConfig=GeminiGenerationConfig(
|
||||||
|
temperature=model["temperature"],
|
||||||
|
topP=model["top_p"],
|
||||||
|
maxOutputTokens=model["max_output_tokens"],
|
||||||
|
seed=seed if seed > 0 else None,
|
||||||
|
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||||
|
),
|
||||||
|
systemInstruction=gemini_system_prompt,
|
||||||
|
),
|
||||||
|
response_model=GeminiGenerateContentResponse,
|
||||||
|
price_extractor=calculate_tokens_price,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_text = get_text_from_response(response)
|
||||||
|
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||||
|
|
||||||
|
|
||||||
class GeminiInputFiles(IO.ComfyNode):
|
class GeminiInputFiles(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Loads and formats input files for use with the Gemini API.
|
Loads and formats input files for use with the Gemini API.
|
||||||
@ -1129,6 +1424,26 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
|||||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||||
advanced=True,
|
advanced=True,
|
||||||
),
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"temperature",
|
||||||
|
default=1.0,
|
||||||
|
min=0.0,
|
||||||
|
max=2.0,
|
||||||
|
step=0.01,
|
||||||
|
optional=True,
|
||||||
|
tooltip="Controls randomness in generation. Lower is more focused/deterministic.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"top_p",
|
||||||
|
default=0.95,
|
||||||
|
min=0.0,
|
||||||
|
max=1.0,
|
||||||
|
step=0.01,
|
||||||
|
optional=True,
|
||||||
|
tooltip="Nucleus sampling threshold. Lower is more focused, higher more diverse.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Image.Output(),
|
IO.Image.Output(),
|
||||||
@ -1165,6 +1480,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
|||||||
seed: int,
|
seed: int,
|
||||||
response_modalities: str,
|
response_modalities: str,
|
||||||
system_prompt: str = "",
|
system_prompt: str = "",
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 0.95,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
model_choice = model["model"]
|
model_choice = model["model"]
|
||||||
@ -1204,6 +1521,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
|||||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||||
imageConfig=image_config,
|
imageConfig=image_config,
|
||||||
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||||
|
temperature=temperature,
|
||||||
|
topP=top_p,
|
||||||
),
|
),
|
||||||
systemInstruction=gemini_system_prompt,
|
systemInstruction=gemini_system_prompt,
|
||||||
),
|
),
|
||||||
@ -1222,6 +1541,7 @@ class GeminiExtension(ComfyExtension):
|
|||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
GeminiNode,
|
GeminiNode,
|
||||||
|
GeminiNodeV2,
|
||||||
GeminiImage,
|
GeminiImage,
|
||||||
GeminiImage2,
|
GeminiImage2,
|
||||||
GeminiNanoBanana2,
|
GeminiNanoBanana2,
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from comfy_api_nodes.apis.ideogram import (
|
|||||||
ImageRequest,
|
ImageRequest,
|
||||||
IdeogramV3Request,
|
IdeogramV3Request,
|
||||||
IdeogramV3EditRequest,
|
IdeogramV3EditRequest,
|
||||||
|
IdeogramV4Request,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@ -17,6 +18,7 @@ from comfy_api_nodes.util import (
|
|||||||
download_url_as_bytesio,
|
download_url_as_bytesio,
|
||||||
resize_mask_to_image,
|
resize_mask_to_image,
|
||||||
sync_op,
|
sync_op,
|
||||||
|
validate_string,
|
||||||
)
|
)
|
||||||
|
|
||||||
V1_V1_RES_MAP = {
|
V1_V1_RES_MAP = {
|
||||||
@ -798,6 +800,119 @@ class IdeogramV3(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||||
|
|
||||||
|
|
||||||
|
class IdeogramV4(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="IdeogramV4",
|
||||||
|
display_name="Ideogram V4",
|
||||||
|
category="partner/image/Ideogram",
|
||||||
|
description="Generates images using the Ideogram 4.0 model from a text prompt.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text prompt for the image generation.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=[
|
||||||
|
"Auto",
|
||||||
|
"2048x2048 (1:1)",
|
||||||
|
"1440x2880 (1:2)",
|
||||||
|
"2880x1440 (2:1)",
|
||||||
|
"1664x2496 (2:3)",
|
||||||
|
"2496x1664 (3:2)",
|
||||||
|
"1792x2240 (4:5)",
|
||||||
|
"2240x1792 (5:4)",
|
||||||
|
"1440x2560 (9:16)",
|
||||||
|
"2560x1440 (16:9)",
|
||||||
|
"1600x2560 (5:8)",
|
||||||
|
"2560x1600 (8:5)",
|
||||||
|
"1728x2304 (3:4)",
|
||||||
|
"2304x1728 (4:3)",
|
||||||
|
"1296x3168 (9:22)",
|
||||||
|
"3168x1296 (22:9)",
|
||||||
|
"1152x2944 (9:23)",
|
||||||
|
"2944x1152 (23:9)",
|
||||||
|
"1248x3328 (3:8)",
|
||||||
|
"3328x1248 (8:3)",
|
||||||
|
"1280x3072 (5:12)",
|
||||||
|
"3072x1280 (12:5)",
|
||||||
|
],
|
||||||
|
default="Auto",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"rendering_speed",
|
||||||
|
options=["DEFAULT", "TURBO", "QUALITY"],
|
||||||
|
default="DEFAULT",
|
||||||
|
tooltip="Controls the trade-off between generation speed and quality.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
control_after_generate=True,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Image.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$speed := widgets.rendering_speed;
|
||||||
|
$price :=
|
||||||
|
$contains($speed,"turbo") ? 0.0429 :
|
||||||
|
$contains($speed,"quality") ? 0.143 :
|
||||||
|
0.0858;
|
||||||
|
{"type":"usd","usd": $price}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
resolution: str,
|
||||||
|
rendering_speed: str,
|
||||||
|
seed: int,
|
||||||
|
):
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/ideogram/ideogram-v4/generate", method="POST"),
|
||||||
|
response_model=IdeogramGenerateResponse,
|
||||||
|
data=IdeogramV4Request(
|
||||||
|
text_prompt=prompt,
|
||||||
|
resolution=resolution.split(" ")[0] if resolution != "Auto" else None,
|
||||||
|
rendering_speed=rendering_speed,
|
||||||
|
),
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.data or len(response.data) == 0:
|
||||||
|
raise Exception("No images were generated in the response")
|
||||||
|
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||||
|
if not image_urls:
|
||||||
|
raise Exception("No image URLs were generated in the response")
|
||||||
|
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||||
|
|
||||||
|
|
||||||
class IdeogramExtension(ComfyExtension):
|
class IdeogramExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -805,6 +920,7 @@ class IdeogramExtension(ComfyExtension):
|
|||||||
IdeogramV1,
|
IdeogramV1,
|
||||||
IdeogramV2,
|
IdeogramV2,
|
||||||
IdeogramV3,
|
IdeogramV3,
|
||||||
|
IdeogramV4,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -42,9 +42,11 @@ async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Ima
|
|||||||
|
|
||||||
|
|
||||||
_MODEL_MEDIUM = "Krea 2 Medium"
|
_MODEL_MEDIUM = "Krea 2 Medium"
|
||||||
|
_MODEL_MEDIUM_TURBO = "Krea 2 Medium Turbo"
|
||||||
_MODEL_LARGE = "Krea 2 Large"
|
_MODEL_LARGE = "Krea 2 Large"
|
||||||
_MODEL_ENDPOINTS: dict[str, str] = {
|
_MODEL_ENDPOINTS: dict[str, str] = {
|
||||||
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
|
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
|
||||||
|
_MODEL_MEDIUM_TURBO: "/proxy/krea/generate/image/krea/krea-2/medium-turbo",
|
||||||
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
|
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,7 +59,7 @@ _UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F
|
|||||||
|
|
||||||
|
|
||||||
def _krea_model_inputs() -> list:
|
def _krea_model_inputs() -> list:
|
||||||
"""Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo."""
|
"""Nested inputs shared by Krea 2 Medium, Medium Turbo and Large under the DynamicCombo."""
|
||||||
return [
|
return [
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
@ -123,6 +125,7 @@ class Krea2ImageNode(IO.ComfyNode):
|
|||||||
"model",
|
"model",
|
||||||
options=[
|
options=[
|
||||||
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
|
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
|
||||||
|
IO.DynamicCombo.Option(_MODEL_MEDIUM_TURBO, _krea_model_inputs()),
|
||||||
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
|
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
|
||||||
],
|
],
|
||||||
tooltip="Krea 2 Medium is best for expressive illustrations; "
|
tooltip="Krea 2 Medium is best for expressive illustrations; "
|
||||||
@ -151,14 +154,15 @@ class Krea2ImageNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
$isLarge := widgets.model = "krea 2 large";
|
$rates := {
|
||||||
|
"krea 2 medium turbo": {"text": 0.015, "style": 0.0175, "moodboard": 0.02},
|
||||||
|
"krea 2 medium": {"text": 0.03, "style": 0.035, "moodboard": 0.04},
|
||||||
|
"krea 2 large": {"text": 0.06, "style": 0.065, "moodboard": 0.07}
|
||||||
|
};
|
||||||
|
$r := $lookup($rates, widgets.model);
|
||||||
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
|
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
|
||||||
$hasStyle := $lookup(inputs, "model.style_reference").connected;
|
$hasStyle := $lookup(inputs, "model.style_reference").connected;
|
||||||
$usd := $hasMoodboard
|
$usd := $hasMoodboard ? $r.moodboard : ($hasStyle ? $r.style : $r.text);
|
||||||
? ($isLarge ? 0.07 : 0.04)
|
|
||||||
: ($hasStyle
|
|
||||||
? ($isLarge ? 0.065 : 0.035)
|
|
||||||
: ($isLarge ? 0.06 : 0.03));
|
|
||||||
{"type":"usd","usd": $usd}
|
{"type":"usd","usd": $usd}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
|
|||||||
@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="SaveAudio",
|
node_id="SaveAudio",
|
||||||
search_aliases=["export flac"],
|
search_aliases=["export flac"],
|
||||||
display_name="Save Audio (FLAC)",
|
display_name="Save Audio (FLAC) (Deprecated)",
|
||||||
category="audio",
|
category="audio",
|
||||||
essentials_category="Audio",
|
essentials_category="Audio",
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -167,6 +167,7 @@ class SaveAudio(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -186,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="SaveAudioMP3",
|
node_id="SaveAudioMP3",
|
||||||
search_aliases=["export mp3"],
|
search_aliases=["export mp3"],
|
||||||
display_name="Save Audio (MP3)",
|
display_name="Save Audio (MP3) (Deprecated)",
|
||||||
category="audio",
|
category="audio",
|
||||||
essentials_category="Audio",
|
essentials_category="Audio",
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -196,6 +197,7 @@ class SaveAudioMP3(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -217,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="SaveAudioOpus",
|
node_id="SaveAudioOpus",
|
||||||
search_aliases=["export opus"],
|
search_aliases=["export opus"],
|
||||||
display_name="Save Audio (Opus)",
|
display_name="Save Audio (Opus) (Deprecated)",
|
||||||
category="audio",
|
category="audio",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Audio.Input("audio"),
|
IO.Audio.Input("audio"),
|
||||||
@ -226,6 +228,7 @@ class SaveAudioOpus(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -241,6 +244,54 @@ class SaveAudioOpus(IO.ComfyNode):
|
|||||||
save_opus = execute # TODO: remove
|
save_opus = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class SaveAudioAdvanced(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="SaveAudioAdvanced",
|
||||||
|
search_aliases=["save audio", "export audio", "output audio", "write audio", "flac", "mp3", "opus"],
|
||||||
|
display_name="Save Audio (Advanced)",
|
||||||
|
description="Saves the input audio to your ComfyUI output directory.",
|
||||||
|
category="audio",
|
||||||
|
inputs=[
|
||||||
|
IO.Audio.Input("audio", tooltip="The audio to save."),
|
||||||
|
IO.String.Input(
|
||||||
|
"filename_prefix",
|
||||||
|
default="audio/ComfyUI",
|
||||||
|
tooltip=(
|
||||||
|
"The prefix for the file to save. May include formatting tokens "
|
||||||
|
"such as %date:yyyy-MM-dd%."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"format",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option("flac", []),
|
||||||
|
IO.DynamicCombo.Option("mp3", [
|
||||||
|
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
|
||||||
|
]),
|
||||||
|
IO.DynamicCombo.Option("opus", [
|
||||||
|
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
|
||||||
|
]),
|
||||||
|
],
|
||||||
|
tooltip="The file format in which to save the audio.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, audio, filename_prefix: str, format: dict) -> IO.NodeOutput:
|
||||||
|
file_format = format.get("format", None)
|
||||||
|
quality = format.get("quality", None)
|
||||||
|
if quality:
|
||||||
|
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality)
|
||||||
|
else:
|
||||||
|
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format)
|
||||||
|
return IO.NodeOutput(ui=ui)
|
||||||
|
|
||||||
|
|
||||||
class PreviewAudio(IO.ComfyNode):
|
class PreviewAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -822,6 +873,7 @@ class AudioExtension(ComfyExtension):
|
|||||||
SaveAudio,
|
SaveAudio,
|
||||||
SaveAudioMP3,
|
SaveAudioMP3,
|
||||||
SaveAudioOpus,
|
SaveAudioOpus,
|
||||||
|
SaveAudioAdvanced,
|
||||||
LoadAudio,
|
LoadAudio,
|
||||||
PreviewAudio,
|
PreviewAudio,
|
||||||
ConditioningStableAudio,
|
ConditioningStableAudio,
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
|
import comfy.sampler_helpers
|
||||||
|
import comfy.patcher_extension
|
||||||
import comfy.sample
|
import comfy.sample
|
||||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||||
from comfy.k_diffusion import sa_solver
|
from comfy.k_diffusion import sa_solver
|
||||||
@ -894,6 +896,85 @@ class DualCFGGuider(io.ComfyNode):
|
|||||||
|
|
||||||
get_guider = execute
|
get_guider = execute
|
||||||
|
|
||||||
|
class Guider_DualModel(comfy.samplers.CFGGuider):
|
||||||
|
# Runs the positive (cond) pass on the main model and the negative (uncond) pass on a separate model
|
||||||
|
def __init__(self, model_patcher, uncond_model_patcher):
|
||||||
|
super().__init__(model_patcher)
|
||||||
|
self.uncond_model_patcher = uncond_model_patcher
|
||||||
|
self.uncond_inner = None
|
||||||
|
|
||||||
|
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
|
||||||
|
self.uncond_inner = None
|
||||||
|
self.uncond_loaded = []
|
||||||
|
self._uncond_neg = None
|
||||||
|
# skip at cfg 1.0
|
||||||
|
if not math.isclose(self.cfg, 1.0):
|
||||||
|
uc = {"negative": list(map(lambda a: a.copy(), self.conds["negative"]))}
|
||||||
|
self.uncond_inner, uc, self.uncond_loaded = comfy.sampler_helpers.prepare_sampling(
|
||||||
|
self.uncond_model_patcher, noise.shape, uc, self.uncond_model_patcher.model_options)
|
||||||
|
self._uncond_neg = uc["negative"]
|
||||||
|
self.uncond_model_patcher.pre_run()
|
||||||
|
try:
|
||||||
|
return super().outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||||
|
finally:
|
||||||
|
if self.uncond_inner is not None:
|
||||||
|
self.uncond_model_patcher.cleanup()
|
||||||
|
comfy.sampler_helpers.cleanup_models({"negative": self._uncond_neg}, self.uncond_loaded)
|
||||||
|
self.uncond_inner = None
|
||||||
|
|
||||||
|
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
|
||||||
|
if self.uncond_inner is not None:
|
||||||
|
li = latent_image
|
||||||
|
if li is not None and torch.count_nonzero(li) > 0:
|
||||||
|
li = self.uncond_inner.process_latent_in(li)
|
||||||
|
self._uncond_conds = comfy.samplers.process_conds(
|
||||||
|
self.uncond_inner, noise, {"negative": self._uncond_neg}, device, li, denoise_mask, seed, latent_shapes=latent_shapes)["negative"]
|
||||||
|
return super().inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||||
|
|
||||||
|
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||||
|
positive = self.conds.get("positive", None)
|
||||||
|
cond = comfy.samplers.calc_cond_batch(self.inner_model, [positive], x, timestep, model_options)[0]
|
||||||
|
# uncond model not loaded (base cfg==1/no negative), or cfg driven to 1.0 this step -> single model, cond only
|
||||||
|
if self.uncond_inner is None or (math.isclose(self.cfg, 1.0) and not model_options.get("disable_cfg1_optimization", False)):
|
||||||
|
return cond
|
||||||
|
|
||||||
|
uncond_model_options = model_options
|
||||||
|
if "multigpu_clones" in model_options: # TODO: support multigpu instead of just running uncond on a single GPU
|
||||||
|
uncond_model_options = {k: v for k, v in model_options.items() if k != "multigpu_clones"}
|
||||||
|
uncond = comfy.samplers.calc_cond_batch(self.uncond_inner, [self._uncond_conds], x, timestep, uncond_model_options)[0]
|
||||||
|
return comfy.samplers.cfg_function(self.inner_model, cond, uncond, self.cfg, x, timestep,
|
||||||
|
model_options=model_options, cond=positive, uncond=self._uncond_conds)
|
||||||
|
|
||||||
|
class DualModelGuider(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="DualModelGuider",
|
||||||
|
display_name="Dual Model CFG Guider",
|
||||||
|
category="model/sampling/guiders",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model", tooltip="Model used for the positive (conditional) pass."),
|
||||||
|
io.Model.Input("model_negative", optional=True, tooltip="Model used for the negative (unconditional) pass. Use the same model for ordinary CFG."),
|
||||||
|
io.Conditioning.Input("positive"),
|
||||||
|
io.Float.Input("cfg", default=4.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
||||||
|
io.Conditioning.Input("negative", optional=True, tooltip="Negative conditioning run on the negative model. Leave unconnected for a text-free (image-only) unconditional pass."),
|
||||||
|
],
|
||||||
|
outputs=[io.Guider.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model, positive, cfg, model_negative=None, negative=None) -> io.NodeOutput:
|
||||||
|
if negative is None:
|
||||||
|
negative = [[None, {}]] # null cond -> no cross_attn -> model runs image-only
|
||||||
|
|
||||||
|
guider = Guider_DualModel(model, model_negative) if model_negative is not None else comfy.samplers.CFGGuider(model)
|
||||||
|
guider.set_conds(positive, negative)
|
||||||
|
guider.set_cfg(cfg)
|
||||||
|
return io.NodeOutput(guider)
|
||||||
|
|
||||||
|
get_guider = execute
|
||||||
|
|
||||||
class DisableNoise(io.ComfyNode):
|
class DisableNoise(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -1054,11 +1135,53 @@ class ManualSigmas(io.ComfyNode):
|
|||||||
sigmas = torch.FloatTensor(sigmas)
|
sigmas = torch.FloatTensor(sigmas)
|
||||||
return io.NodeOutput(sigmas)
|
return io.NodeOutput(sigmas)
|
||||||
|
|
||||||
|
class CFGOverride(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CFGOverride",
|
||||||
|
display_name="CFG Override",
|
||||||
|
description="Override cfg to a fixed value over a [start, end] percent (sigma) range. "
|
||||||
|
"With multiple overrides, the one nearest the sampler wins on overlap.",
|
||||||
|
category="sampling/custom_sampling",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
||||||
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model, cfg, start_percent, end_percent) -> io.NodeOutput:
|
||||||
|
ms = model.get_model_object("model_sampling")
|
||||||
|
sigma_hi = ms.percent_to_sigma(start_percent) # percent->sigma decreasing, so hi >= lo
|
||||||
|
sigma_lo = ms.percent_to_sigma(end_percent)
|
||||||
|
|
||||||
|
def predict_noise_wrapper(executor, *args, **kwargs):
|
||||||
|
sigma = float(args[1].flatten()[0]) # args = (x, timestep, model_options, seed)
|
||||||
|
if not (sigma_lo <= sigma <= sigma_hi):
|
||||||
|
return executor(*args, **kwargs)
|
||||||
|
guider = executor.class_obj # guider.cfg feeds cond_scale
|
||||||
|
saved = guider.cfg
|
||||||
|
guider.cfg = cfg
|
||||||
|
try:
|
||||||
|
return executor(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
guider.cfg = saved # restore for other steps/overrides
|
||||||
|
|
||||||
|
m = model.clone()
|
||||||
|
m.add_wrapper(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, predict_noise_wrapper)
|
||||||
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|
||||||
class CustomSamplersExtension(ComfyExtension):
|
class CustomSamplersExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
SamplerCustom,
|
SamplerCustom,
|
||||||
|
CFGOverride,
|
||||||
BasicScheduler,
|
BasicScheduler,
|
||||||
KarrasScheduler,
|
KarrasScheduler,
|
||||||
ExponentialScheduler,
|
ExponentialScheduler,
|
||||||
@ -1087,6 +1210,7 @@ class CustomSamplersExtension(ComfyExtension):
|
|||||||
SamplingPercentToSigma,
|
SamplingPercentToSigma,
|
||||||
CFGGuider,
|
CFGGuider,
|
||||||
DualCFGGuider,
|
DualCFGGuider,
|
||||||
|
DualModelGuider,
|
||||||
BasicGuider,
|
BasicGuider,
|
||||||
RandomNoise,
|
RandomNoise,
|
||||||
DisableNoise,
|
DisableNoise,
|
||||||
|
|||||||
@ -411,6 +411,21 @@ class ImageProcessingNode(io.ComfyNode):
|
|||||||
|
|
||||||
return has_group
|
return has_group
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _ensure_image_list(cls, images):
|
||||||
|
"""Normalize to a flat list of [1, H, W, C] tensors."""
|
||||||
|
if isinstance(images, torch.Tensor):
|
||||||
|
if images.ndim != 4:
|
||||||
|
raise ValueError(f"Expected 4D image tensor, got shape {tuple(images.shape)}")
|
||||||
|
return [images[i:i+1] for i in range(images.shape[0])]
|
||||||
|
|
||||||
|
flat = []
|
||||||
|
for item in images:
|
||||||
|
if not isinstance(item, torch.Tensor) or item.ndim != 4:
|
||||||
|
raise ValueError(f"Expected 4D image tensor, got {type(item).__name__} shape {getattr(item, 'shape', None)}")
|
||||||
|
flat.extend([item[i:i+1] for i in range(item.shape[0])])
|
||||||
|
return flat
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
if cls.node_id is None:
|
if cls.node_id is None:
|
||||||
@ -458,6 +473,9 @@ class ImageProcessingNode(io.ComfyNode):
|
|||||||
"""Execute the node. Routes to _process or _group_process based on mode."""
|
"""Execute the node. Routes to _process or _group_process based on mode."""
|
||||||
is_group = cls._detect_processing_mode()
|
is_group = cls._detect_processing_mode()
|
||||||
|
|
||||||
|
if is_group:
|
||||||
|
images = cls._ensure_image_list(images)
|
||||||
|
|
||||||
# Extract scalar values from lists for parameters
|
# Extract scalar values from lists for parameters
|
||||||
params = {}
|
params = {}
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
|
|||||||
@ -488,7 +488,7 @@ class SplatToFile3D(IO.ComfyNode):
|
|||||||
"spz: Niantic gzip-compressed (~10x smaller), base color only "
|
"spz: Niantic gzip-compressed (~10x smaller), base color only "
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[IO.File3DAny.Output(display_name="model_3d")],
|
outputs=[IO.File3DSplatAny.Output(display_name="model_3d")],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -516,7 +516,7 @@ class File3DToSplat(IO.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
IO.MultiType.Input(
|
IO.MultiType.Input(
|
||||||
IO.File3DAny.Input("model_3d"),
|
IO.File3DAny.Input("model_3d"),
|
||||||
types=[IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ],
|
types=[IO.File3DSplatAny, IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ],
|
||||||
tooltip="A gaussian splat 3D file",
|
tooltip="A gaussian splat 3D file",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
|||||||
64
comfy_extras/nodes_ideogram4.py
Normal file
64
comfy_extras/nodes_ideogram4.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
"""Ideogram 4 sampling helper
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
_LOGSNR_MIN = -15.0
|
||||||
|
_LOGSNR_MAX = 18.0
|
||||||
|
|
||||||
|
|
||||||
|
def _logit_normal_schedule(u, mean, std):
|
||||||
|
# Reference time (0=noise..1=clean) via the probit/ndtri quantile.
|
||||||
|
u = torch.as_tensor(u, dtype=torch.float64)
|
||||||
|
t = 1.0 - torch.special.expit(mean + std * torch.special.ndtri(u))
|
||||||
|
t_min = 1.0 / (1.0 + math.exp(0.5 * _LOGSNR_MAX))
|
||||||
|
t_max = 1.0 / (1.0 + math.exp(0.5 * _LOGSNR_MIN))
|
||||||
|
return t.clamp(t_min, t_max)
|
||||||
|
|
||||||
|
|
||||||
|
def ideogram4_sigmas(num_steps, width, height, mu, std):
|
||||||
|
"""Descending sigmas (len num_steps+1) for the reference schedule.
|
||||||
|
|
||||||
|
mu + the resolution term form the logSNR shift; std is the spread.
|
||||||
|
"""
|
||||||
|
mean = mu + 0.5 * math.log((width * height) / (512 * 512))
|
||||||
|
u = torch.linspace(0.0, 1.0, num_steps + 1, dtype=torch.float64)
|
||||||
|
sigmas = (1.0 - _logit_normal_schedule(u, mean, std)).flip(0)
|
||||||
|
sigmas[-1] = 0.0 # clamp leaves ~6e-4; force full denoise
|
||||||
|
return sigmas.to(torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class Ideogram4Scheduler(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="Ideogram4Scheduler",
|
||||||
|
display_name="Ideogram 4 Scheduler",
|
||||||
|
category="sampling/custom_sampling/schedulers",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input("steps", default=20, min=1, max=200),
|
||||||
|
io.Int.Input("width", default=1024, min=256, max=8192, step=16),
|
||||||
|
io.Int.Input("height", default=1024, min=256, max=8192, step=16),
|
||||||
|
io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.05),
|
||||||
|
io.Float.Input("std", default=1.75, min=0.1, max=5.0, step=0.05),
|
||||||
|
],
|
||||||
|
outputs=[io.Sigmas.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, steps, width, height, mu, std) -> io.NodeOutput:
|
||||||
|
return io.NodeOutput(ideogram4_sigmas(steps, width, height, mu, std))
|
||||||
|
|
||||||
|
|
||||||
|
class Ideogram4Extension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [Ideogram4Scheduler]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> Ideogram4Extension:
|
||||||
|
return Ideogram4Extension()
|
||||||
@ -51,6 +51,14 @@ class Load3D(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_inputs(cls, model_file, **kwargs) -> bool | str:
|
||||||
|
if not model_file or model_file == "none":
|
||||||
|
return True
|
||||||
|
if not folder_paths.exists_annotated_filepath(model_file):
|
||||||
|
return f"Invalid 3D model file: {model_file}"
|
||||||
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
|
def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
|
||||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||||
@ -136,7 +144,7 @@ class Preview3DAdvanced(IO.ComfyNode):
|
|||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.MultiType.Input(
|
IO.MultiType.Input(
|
||||||
"model_file",
|
"model_3d",
|
||||||
types=[
|
types=[
|
||||||
IO.File3DGLB,
|
IO.File3DGLB,
|
||||||
IO.File3DGLTF,
|
IO.File3DGLTF,
|
||||||
@ -148,34 +156,161 @@ class Preview3DAdvanced(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
tooltip="3D model file from an upstream 3D node.",
|
tooltip="3D model file from an upstream 3D node.",
|
||||||
),
|
),
|
||||||
IO.Load3D.Input("image"),
|
|
||||||
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
|
|
||||||
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
|
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
|
||||||
|
IO.Load3D.Input("viewport_state"),
|
||||||
|
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
|
||||||
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
||||||
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.File3DAny.Output(display_name="model_file"),
|
IO.File3DAny.Output(display_name="model_3d"),
|
||||||
IO.Load3DCamera.Output(display_name="camera_info"),
|
|
||||||
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
|
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
|
||||||
|
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||||
IO.Int.Output(display_name="width"),
|
IO.Int.Output(display_name="width"),
|
||||||
IO.Int.Output(display_name="height"),
|
IO.Int.Output(display_name="height"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model_file: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput:
|
def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
|
||||||
filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_file.format}"
|
filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_3d.format}"
|
||||||
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
|
model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename))
|
||||||
|
|
||||||
camera_info_input = kwargs.get("camera_info", None)
|
camera_info_input = kwargs.get("camera_info", None)
|
||||||
camera_info = camera_info_input if camera_info_input is not None else image['camera_info']
|
camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info']
|
||||||
model_3d_info_input = kwargs.get("model_3d_info", None)
|
model_3d_info_input = kwargs.get("model_3d_info", None)
|
||||||
model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', [])
|
model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', [])
|
||||||
return IO.NodeOutput(
|
return IO.NodeOutput(
|
||||||
model_file,
|
model_3d,
|
||||||
camera_info,
|
|
||||||
model_3d_info,
|
model_3d_info,
|
||||||
|
camera_info,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewGaussianSplat(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="PreviewGaussianSplat",
|
||||||
|
display_name="Preview Splat",
|
||||||
|
category="3d",
|
||||||
|
is_experimental=True,
|
||||||
|
is_output_node=True,
|
||||||
|
search_aliases=[
|
||||||
|
"view splat",
|
||||||
|
"view gaussian",
|
||||||
|
"view gaussian splat",
|
||||||
|
"preview gaussian",
|
||||||
|
"preview gaussian splat",
|
||||||
|
"view 3dgs",
|
||||||
|
"preview 3dgs",
|
||||||
|
"preview ply",
|
||||||
|
"preview spz",
|
||||||
|
"preview splat",
|
||||||
|
"preview ksplat",
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
IO.MultiType.Input(
|
||||||
|
"model_3d",
|
||||||
|
types=[
|
||||||
|
IO.File3DSplatAny,
|
||||||
|
IO.File3DPLY,
|
||||||
|
IO.File3DSPLAT,
|
||||||
|
IO.File3DSPZ,
|
||||||
|
IO.File3DKSPLAT,
|
||||||
|
],
|
||||||
|
tooltip="A gaussian splat 3D file.",
|
||||||
|
),
|
||||||
|
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
|
||||||
|
IO.Load3D.Input("viewport_state"),
|
||||||
|
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
|
||||||
|
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
||||||
|
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.File3DSplatAny.Output(display_name="model_3d"),
|
||||||
|
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
|
||||||
|
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||||
|
IO.Int.Output(display_name="width"),
|
||||||
|
IO.Int.Output(display_name="height"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
|
||||||
|
filename = f"preview_splat_{uuid.uuid4().hex}.{model_3d.format}"
|
||||||
|
model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename))
|
||||||
|
|
||||||
|
camera_info_input = kwargs.get("camera_info", None)
|
||||||
|
camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info']
|
||||||
|
model_3d_info_input = kwargs.get("model_3d_info", None)
|
||||||
|
model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', [])
|
||||||
|
return IO.NodeOutput(
|
||||||
|
model_3d,
|
||||||
|
model_3d_info,
|
||||||
|
camera_info,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewPointCloud(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="PreviewPointCloud",
|
||||||
|
display_name="Preview Point Cloud",
|
||||||
|
category="3d",
|
||||||
|
is_experimental=True,
|
||||||
|
is_output_node=True,
|
||||||
|
search_aliases=[
|
||||||
|
"view point cloud",
|
||||||
|
"view pointcloud",
|
||||||
|
"preview point cloud",
|
||||||
|
"preview pointcloud",
|
||||||
|
"preview ply",
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
IO.MultiType.Input(
|
||||||
|
"model_3d",
|
||||||
|
types=[
|
||||||
|
IO.File3DPointCloudAny,
|
||||||
|
IO.File3DPLY,
|
||||||
|
],
|
||||||
|
tooltip="Point cloud file (.ply)",
|
||||||
|
),
|
||||||
|
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
|
||||||
|
IO.Load3D.Input("viewport_state"),
|
||||||
|
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
|
||||||
|
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
||||||
|
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.File3DPointCloudAny.Output(display_name="model_3d"),
|
||||||
|
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
|
||||||
|
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||||
|
IO.Int.Output(display_name="width"),
|
||||||
|
IO.Int.Output(display_name="height"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
|
||||||
|
filename = f"preview_pointcloud_{uuid.uuid4().hex}.{model_3d.format}"
|
||||||
|
model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename))
|
||||||
|
|
||||||
|
camera_info_input = kwargs.get("camera_info", None)
|
||||||
|
camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info']
|
||||||
|
model_3d_info_input = kwargs.get("model_3d_info", None)
|
||||||
|
model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', [])
|
||||||
|
return IO.NodeOutput(
|
||||||
|
model_3d,
|
||||||
|
model_3d_info,
|
||||||
|
camera_info,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
|
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
|
||||||
@ -189,6 +324,8 @@ class Load3DExtension(ComfyExtension):
|
|||||||
Load3D,
|
Load3D,
|
||||||
Preview3D,
|
Preview3D,
|
||||||
Preview3DAdvanced,
|
Preview3DAdvanced,
|
||||||
|
PreviewGaussianSplat,
|
||||||
|
PreviewPointCloud,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -337,6 +337,12 @@ class SaveGLB(IO.ComfyNode):
|
|||||||
IO.File3DFBX,
|
IO.File3DFBX,
|
||||||
IO.File3DSTL,
|
IO.File3DSTL,
|
||||||
IO.File3DUSDZ,
|
IO.File3DUSDZ,
|
||||||
|
IO.File3DPLY,
|
||||||
|
IO.File3DSPLAT,
|
||||||
|
IO.File3DSPZ,
|
||||||
|
IO.File3DKSPLAT,
|
||||||
|
IO.File3DSplatAny,
|
||||||
|
IO.File3DPointCloudAny,
|
||||||
IO.File3DAny,
|
IO.File3DAny,
|
||||||
],
|
],
|
||||||
tooltip="Mesh or 3D file to save",
|
tooltip="Mesh or 3D file to save",
|
||||||
|
|||||||
@ -19,7 +19,7 @@ class SaveWEBM(io.ComfyNode):
|
|||||||
category="video",
|
category="video",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("images"),
|
io.Image.Input("images", tooltip="RGBA images are saved with their alpha channel as transparency (vp9 codec only)."),
|
||||||
io.String.Input("filename_prefix", default="ComfyUI"),
|
io.String.Input("filename_prefix", default="ComfyUI"),
|
||||||
io.Combo.Input("codec", options=["vp9", "av1"]),
|
io.Combo.Input("codec", options=["vp9", "av1"]),
|
||||||
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
|
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
|
||||||
@ -45,18 +45,25 @@ class SaveWEBM(io.ComfyNode):
|
|||||||
for x in cls.hidden.extra_pnginfo:
|
for x in cls.hidden.extra_pnginfo:
|
||||||
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||||
|
|
||||||
|
# Save transparency when the images carry an alpha channel (RGBA) and the codec supports it.
|
||||||
|
# vp9 -> yuva420p; other codecs have no usable alpha path, so the alpha is ignored.
|
||||||
|
save_alpha = images.shape[-1] == 4 and codec == "vp9"
|
||||||
|
|
||||||
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
|
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
|
||||||
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
|
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
|
||||||
stream.width = images.shape[-2]
|
stream.width = images.shape[-2]
|
||||||
stream.height = images.shape[-3]
|
stream.height = images.shape[-3]
|
||||||
stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
|
stream.pix_fmt = "yuva420p" if save_alpha else ("yuv420p10le" if codec == "av1" else "yuv420p")
|
||||||
stream.bit_rate = 0
|
stream.bit_rate = 0
|
||||||
stream.options = {'crf': str(crf)}
|
stream.options = {'crf': str(crf)}
|
||||||
if codec == "av1":
|
if codec == "av1":
|
||||||
stream.options["preset"] = "6"
|
stream.options["preset"] = "6"
|
||||||
|
|
||||||
for frame in images:
|
for frame in images:
|
||||||
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
|
if save_alpha:
|
||||||
|
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :4] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgba")
|
||||||
|
else:
|
||||||
|
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
|
||||||
for packet in stream.encode(frame):
|
for packet in stream.encode(frame):
|
||||||
container.mux(packet)
|
container.mux(packet)
|
||||||
container.mux(stream.encode())
|
container.mux(stream.encode())
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.23.0"
|
__version__ = "0.24.0"
|
||||||
|
|||||||
3
nodes.py
3
nodes.py
@ -969,7 +969,7 @@ class CLIPLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit"], ),
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@ -2362,6 +2362,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_model_downscale.py",
|
"nodes_model_downscale.py",
|
||||||
"nodes_images.py",
|
"nodes_images.py",
|
||||||
"nodes_video_model.py",
|
"nodes_video_model.py",
|
||||||
|
"nodes_ideogram4.py",
|
||||||
"nodes_train.py",
|
"nodes_train.py",
|
||||||
"nodes_dataset.py",
|
"nodes_dataset.py",
|
||||||
"nodes_sag.py",
|
"nodes_sag.py",
|
||||||
|
|||||||
16661
openapi.yaml
16661
openapi.yaml
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.23.0"
|
version = "0.24.0"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.44.19
|
comfyui-frontend-package==1.45.15
|
||||||
comfyui-workflow-templates==0.9.92
|
comfyui-workflow-templates==0.9.98
|
||||||
comfyui-embedded-docs==0.5.2
|
comfyui-embedded-docs==0.5.2
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
|
|||||||
filelock
|
filelock
|
||||||
av>=16.0.0
|
av>=16.0.0
|
||||||
comfy-kitchen==0.2.10
|
comfy-kitchen==0.2.10
|
||||||
comfy-aimdo==0.4.8
|
comfy-aimdo==0.4.9
|
||||||
requests
|
requests
|
||||||
simpleeval>=1.0.0
|
simpleeval>=1.0.0
|
||||||
blake3
|
blake3
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user