Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-04-30 20:32:45 +08:00)

Commit 11162f9e74: Merge remote-tracking branch 'upstream/master' into gemma4
@@ -195,7 +195,9 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
 
 #### Alternative Downloads:
 
-[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
+[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
+
+[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
 
 [Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
 
@@ -67,7 +67,7 @@ class InternalRoutes:
             (entry for entry in os.scandir(directory) if is_visible_file(entry)),
             key=lambda entry: -entry.stat().st_mtime
         )
-        return web.json_response([entry.name for entry in sorted_files], status=200)
+        return web.json_response([f"{entry.name} [{directory_type}]" for entry in sorted_files], status=200)
 
 
     def get_app(self):
@@ -15,7 +15,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
 
     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
     omega = 1.0 / (theta**scale)
-    out = torch.einsum("...n,d->...nd", pos, omega)
+    out = torch.einsum("...n,d->...nd", pos.to(device), omega)
     out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
     return out.to(dtype=torch.float32, device=pos.device)
 
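Note on the rope hunk: the frequencies built above follow the standard RoPE schedule, omega_d = theta^(-2d/dim), and the einsum forms the outer product of positions and frequencies before the cos/sin stack. The fix matters when pos lives on a different device than omega; pos.to(device) keeps the einsum on one device. A minimal numeric sketch of those lines (illustrative only):

import torch

pos = torch.arange(4, dtype=torch.float64)   # token positions
dim, theta = 8, 10000
scale = torch.arange(0, dim, 2, dtype=torch.float64) / dim
omega = 1.0 / (theta ** scale)               # one frequency per channel pair
out = torch.einsum("...n,d->...nd", pos, omega)
print(out.shape)                             # torch.Size([4, 4]): positions x dim/2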
@@ -118,8 +118,6 @@ class ErnieImageAttention(nn.Module):
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)
 
-        query, key = query.to(x.dtype), key.to(x.dtype)
-
         q_flat = query.reshape(B, S, -1)
         k_flat = key.reshape(B, S, -1)
 
@@ -161,16 +159,16 @@ class ErnieImageSharedAdaLNBlock(nn.Module):
 
         residual = x
         x_norm = self.adaLN_sa_ln(x)
-        x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
+        x_norm = x_norm * (1 + scale_msa) + shift_msa
 
         attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
-        x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
+        x = residual + gate_msa * attn_out
 
         residual = x
         x_norm = self.adaLN_mlp_ln(x)
-        x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
+        x_norm = x_norm * (1 + scale_mlp) + shift_mlp
 
-        return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype)
+        return residual + gate_mlp * self.mlp(x_norm)
 
 class ErnieImageAdaLNContinuous(nn.Module):
     def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
@@ -183,7 +181,7 @@ class ErnieImageAdaLNContinuous(nn.Module):
     def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
         scale, shift = self.linear(conditioning).chunk(2, dim=-1)
         x = self.norm(x)
-        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        x = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
        return x
 
 class ErnieImageModel(nn.Module):
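The torch.addcmul rewrite above is a pure fusion: addcmul(input, t1, t2) computes input + t1 * t2 in one kernel, so addcmul(shift, x, 1 + scale) matches the previous x * (1 + scale) + shift elementwise. A quick equivalence check (illustrative sketch, shapes chosen arbitrarily):

import torch

x = torch.randn(2, 5, 16)
scale = torch.randn(2, 16)
shift = torch.randn(2, 16)

a = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
b = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
assert torch.allclose(a, b)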
@@ -34,6 +34,16 @@ class TimestepBlock(nn.Module):
 #This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
 def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
     for layer in ts:
+        if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
+            found_patched = False
+            for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
+                if isinstance(layer, class_type):
+                    x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
+                    found_patched = True
+                    break
+            if found_patched:
+                continue
+
         if isinstance(layer, VideoResBlock):
             x = layer(x, emb, num_video_frames, image_only_indicator)
         elif isinstance(layer, TimestepBlock):
@@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
         elif isinstance(layer, Upsample):
             x = layer(x, output_shape=output_shape)
         else:
-            if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
-                found_patched = False
-                for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
-                    if isinstance(layer, class_type):
-                        x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
-                        found_patched = True
-                        break
-                if found_patched:
-                    continue
             x = layer(x)
     return x
 
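These two hunks relocate the patch dispatch from the else branch to the top of the layer loop, so a forward_timestep_embed_patch can now intercept any layer type (including Upsample), not only layers that fall through the isinstance chain. Entries are (class_type, handler) tuples; the handler receives the layer plus the full call context. A hypothetical registration sketch (the handler name is invented; the tuple format and signature come from the dispatch above, with Upsample imported from comfy.ldm.modules.diffusionmodules.openaimodel):

def upsample_handler(layer, x, emb, context, transformer_options,
                     output_shape, time_context, num_video_frames,
                     image_only_indicator):
    # adjust x here, then run the layer as forward_timestep_embed would
    return layer(x, output_shape=output_shape)

transformer_options = {"patches": {
    "forward_timestep_embed_patch": [(Upsample, upsample_handler)],
}}

SUPIRPatch.register further down uses exactly this tuple format.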
@@ -894,6 +895,12 @@ class UNetModel(nn.Module):
         h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
         h = apply_control(h, control, 'middle')
 
+        if "middle_block_after_patch" in transformer_patches:
+            patch = transformer_patches["middle_block_after_patch"]
+            for p in patch:
+                out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
+                         "timesteps": timesteps, "transformer_options": transformer_options})
+                h = out["h"]
 
         for id, module in enumerate(self.output_blocks):
             transformer_options["block"] = ("output", id)
@@ -905,6 +912,7 @@ class UNetModel(nn.Module):
                 for p in patch:
                     h, hsp = p(h, hsp, transformer_options)
 
+            if hsp is not None:
                 h = th.cat([h, hsp], dim=1)
             del hsp
             if len(hs) > 0:
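The new guard changes the output_block patch contract: a patch may return (h, None) to signal that it has already merged the skip connection, in which case the th.cat is skipped. Sketch of such a patch (the fusion function is hypothetical):

def output_block_patch(h, hsp, transformer_options):
    h = fuse(h, hsp)  # hypothetical replacement for th.cat([h, hsp], dim=1)
    return h, None    # None -> UNetModel skips the concat

SUPIRPatch.output_block in the new file below relies on this to substitute ZeroSFT fusion for the concat.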
comfy/ldm/supir/__init__.py: new file, 0 lines
comfy/ldm/supir/supir_modules.py: new file, 226 lines
@@ -0,0 +1,226 @@
import torch
import torch.nn as nn

from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer
from comfy.ldm.modules.attention import optimized_attention


class ZeroSFT(nn.Module):
    def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None):
        super().__init__()

        ks = 3
        pw = ks // 2

        self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device)

        nhidden = 128

        self.mlp_shared = nn.Sequential(
            operations.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw, dtype=dtype, device=device),
            nn.SiLU()
        )
        self.zero_mul = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)
        self.zero_add = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)

        self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device)
        self.pre_concat = bool(concat_channels != 0)

    def forward(self, c, h, h_ori=None, control_scale=1):
        if h_ori is not None and self.pre_concat:
            h_raw = torch.cat([h_ori, h], dim=1)
        else:
            h_raw = h

        h = h + self.zero_conv(c)
        if h_ori is not None and self.pre_concat:
            h = torch.cat([h_ori, h], dim=1)
        actv = self.mlp_shared(c)
        gamma = self.zero_mul(actv)
        beta = self.zero_add(actv)
        h = self.param_free_norm(h)
        h = torch.addcmul(h + beta, h, gamma)
        if h_ori is not None and not self.pre_concat:
            h = torch.cat([h_ori, h], dim=1)
        return torch.lerp(h_raw, h, control_scale)


class _CrossAttnInner(nn.Module):
    """Inner cross-attention module matching the state_dict layout of the original CrossAttention."""
    def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
        )

    def forward(self, x, context):
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        return self.to_out(optimized_attention(q, k, v, self.heads))


class ZeroCrossAttn(nn.Module):
    def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None):
        super().__init__()
        heads = query_dim // 64
        dim_head = 64
        self.attn = _CrossAttnInner(query_dim, context_dim, heads, dim_head, dtype=dtype, device=device, operations=operations)
        self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device)
        self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device)

    def forward(self, context, x, control_scale=1):
        b, c, h, w = x.shape
        x_in = x

        x = self.attn(
            self.norm1(x).flatten(2).transpose(1, 2),
            self.norm2(context).flatten(2).transpose(1, 2),
        ).transpose(1, 2).unflatten(2, (h, w))

        return x_in + x * control_scale


class GLVControl(nn.Module):
    """SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only)."""
    def __init__(
        self,
        in_channels=4,
        model_channels=320,
        num_res_blocks=2,
        attention_resolutions=(4, 2),
        channel_mult=(1, 2, 4),
        num_head_channels=64,
        transformer_depth=(1, 2, 10),
        context_dim=2048,
        adm_in_channels=2816,
        use_linear_in_transformer=True,
        use_checkpoint=False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.model_channels = model_channels
        time_embed_dim = model_channels * 4

        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
        )

        self.label_emb = nn.Sequential(
            nn.Sequential(
                operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
            )
        )

        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
            )
        ])
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(num_res_blocks):
                layers = [
                    ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels,
                             dtype=dtype, device=device, operations=operations)
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    num_heads = ch // num_head_channels
                    layers.append(
                        SpatialTransformer(ch, num_heads, num_head_channels,
                                           depth=transformer_depth[level], context_dim=context_dim,
                                           use_linear=use_linear_in_transformer,
                                           use_checkpoint=use_checkpoint,
                                           dtype=dtype, device=device, operations=operations)
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
            if level != len(channel_mult) - 1:
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations)
                    )
                )
                ds *= 2

        num_heads = ch // num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
            SpatialTransformer(ch, num_heads, num_head_channels,
                               depth=transformer_depth[-1], context_dim=context_dim,
                               use_linear=use_linear_in_transformer,
                               use_checkpoint=use_checkpoint,
                               dtype=dtype, device=device, operations=operations),
            ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
        )

        self.input_hint_block = TimestepEmbedSequential(
            operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
        )

    def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb) + self.label_emb(y)

        guided_hint = self.input_hint_block(x, emb, context)

        hs = []
        h = xt
        for module in self.input_blocks:
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        hs.append(h)
        return hs


class SUPIR(nn.Module):
    """
    SUPIR model containing GLVControl (control encoder) and project_modules (adapters).
    State dict keys match the original SUPIR checkpoint layout:
        control_model.* -> GLVControl
        project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn
    """
    def __init__(self, device=None, dtype=None, operations=None):
        super().__init__()

        self.control_model = GLVControl(dtype=dtype, device=device, operations=operations)

        project_channel_scale = 2
        cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
        project_channels = [int(c * project_channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3]
        concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
        cross_attn_insert_idx = [6, 3]

        self.project_modules = nn.ModuleList()
        for i in range(len(cond_output_channels)):
            self.project_modules.append(ZeroSFT(
                project_channels[i], cond_output_channels[i],
                concat_channels=concat_channels[i],
                dtype=dtype, device=device, operations=operations,
            ))

        for i in cross_attn_insert_idx:
            self.project_modules.insert(i, ZeroCrossAttn(
                cond_output_channels[i], concat_channels[i],
                dtype=dtype, device=device, operations=operations,
            ))
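Two one-liners in ZeroSFT.forward are worth unpacking: torch.addcmul(h + beta, h, gamma) is h * (1 + gamma) + beta, a SPADE-style affine modulation whose gamma/beta come from the zero_mul/zero_add convs, and torch.lerp(h_raw, h, control_scale) is h_raw + control_scale * (h - h_raw), fading the modulated branch in. Equivalence sketch (illustrative shapes):

import torch

h, gamma, beta = torch.randn(3, 1, 8, 4, 4)
mod = torch.addcmul(h + beta, h, gamma)
assert torch.allclose(mod, h * (1 + gamma) + beta)

h_raw = torch.randn(1, 8, 4, 4)
assert torch.allclose(torch.lerp(h_raw, mod, 0.5), h_raw + 0.5 * (mod - h_raw))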
comfy/ldm/supir/supir_patch.py: new file, 103 lines
@@ -0,0 +1,103 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample


class SUPIRPatch:
    """
    Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters).
    Runs GLVControl lazily on first patch invocation per step, applies adapters through
    middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch.
    """
    SIGMA_MAX = 14.6146

    def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end):
        self.model_patch = model_patch          # CoreModelPatcher wrapping GLVControl
        self.project_modules = project_modules  # nn.ModuleList of ZeroSFT/ZeroCrossAttn
        self.hint_latent = hint_latent          # encoded LQ image latent
        self.strength_start = strength_start
        self.strength_end = strength_end
        self.cached_features = None
        self.adapter_idx = 0
        self.control_idx = 0
        self.current_control_idx = 0
        self.active = True

    def _ensure_features(self, kwargs):
        """Run GLVControl on first call per step, cache results."""
        if self.cached_features is not None:
            return
        x = kwargs["x"]
        b = x.shape[0]
        hint = self.hint_latent.to(device=x.device, dtype=x.dtype)
        if hint.shape[0] != b:
            hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b]
        self.cached_features = self.model_patch.model.control_model(
            hint, kwargs["timesteps"], x,
            kwargs["context"], kwargs["y"]
        )
        self.adapter_idx = len(self.project_modules) - 1
        self.control_idx = len(self.cached_features) - 1

    def _get_control_scale(self, kwargs):
        if self.strength_start == self.strength_end:
            return self.strength_end
        sigma = kwargs["transformer_options"].get("sigmas")
        if sigma is None:
            return self.strength_end
        s = sigma[0].item() if sigma.dim() > 0 else sigma.item()
        t = min(s / self.SIGMA_MAX, 1.0)
        return t * (self.strength_start - self.strength_end) + self.strength_end

    def middle_after(self, kwargs):
        """middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block."""
        self.cached_features = None  # reset from previous step
        self.current_scale = self._get_control_scale(kwargs)
        self.active = self.current_scale > 0
        if not self.active:
            return {"h": kwargs["h"]}
        self._ensure_features(kwargs)
        h = kwargs["h"]
        h = self.project_modules[self.adapter_idx](
            self.cached_features[self.control_idx], h, control_scale=self.current_scale
        )
        self.adapter_idx -= 1
        self.control_idx -= 1
        return {"h": h}

    def output_block(self, h, hsp, transformer_options):
        """output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat."""
        if not self.active:
            return h, hsp
        self.current_control_idx = self.control_idx
        h = self.project_modules[self.adapter_idx](
            self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale
        )
        self.adapter_idx -= 1
        self.control_idx -= 1
        return h, None

    def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw):
        """forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample."""
        block_type, _ = transformer_options["block"]
        if block_type == "output" and self.active and self.cached_features is not None:
            x = self.project_modules[self.adapter_idx](
                self.cached_features[self.current_control_idx], x, control_scale=self.current_scale
            )
            self.adapter_idx -= 1
        return layer(x, output_shape=output_shape)

    def to(self, device_or_dtype):
        if isinstance(device_or_dtype, torch.device):
            self.cached_features = None
        if self.hint_latent is not None:
            self.hint_latent = self.hint_latent.to(device_or_dtype)
        return self

    def models(self):
        return [self.model_patch]

    def register(self, model_patcher):
        """Register all patches on a cloned model patcher."""
        model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch")
        model_patcher.set_model_output_block_patch(self.output_block)
        model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch")
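The strength schedule in _get_control_scale is linear in sigma: t = min(sigma / SIGMA_MAX, 1.0), then t * (strength_start - strength_end) + strength_end, so control strength moves from strength_start at the noisiest steps toward strength_end as sigma approaches 0. Worked numbers (strength values hypothetical):

SIGMA_MAX = 14.6146

def control_scale(sigma, strength_start, strength_end):
    t = min(sigma / SIGMA_MAX, 1.0)
    return t * (strength_start - strength_end) + strength_end

print(control_scale(14.6146, 1.0, 0.5))  # 1.0  (first steps)
print(control_scale(7.3073, 1.0, 0.5))   # 0.75 (halfway down in sigma)
print(control_scale(0.0, 1.0, 0.5))      # 0.5  (final steps)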
@@ -506,6 +506,10 @@ class ModelPatcher:
     def set_model_noise_refiner_patch(self, patch):
         self.set_model_patch(patch, "noise_refiner")
 
+    def set_model_middle_block_after_patch(self, patch):
+        self.set_model_patch(patch, "middle_block_after_patch")
+
+
     def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
         rope_options = self.model_options["transformer_options"].get("rope_options", {})
         rope_options["scale_x"] = scale_x
@@ -1151,7 +1151,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             if param is None:
                 continue
             p = fn(param)
-            if p.is_inference():
+            if (not torch.is_inference_mode_enabled()) and p.is_inference():
                 p = p.clone()
             self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
         for key, buf in self._buffers.items():
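Background for the added guard: tensors created under torch.inference_mode() report is_inference() == True, and the unconditional clone() was the escape hatch that turns them back into normal tensors. Inside an active inference-mode region, though, clone() just yields another inference tensor, so the copy was wasted work; the new check skips it there. Illustration:

import torch

with torch.inference_mode():
    t = torch.ones(3)
    assert t.is_inference() and torch.is_inference_mode_enabled()
    assert t.clone().is_inference()   # clone cannot escape while still inside

assert not torch.is_inference_mode_enabled()
assert not t.clone().is_inference()   # clone outside does escape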
@@ -35,4 +35,4 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
             model_options = model_options.copy()
             model_options["quantization_metadata"] = llama_quantization_metadata
             super().__init__(device=device, dtype=dtype, model_options=model_options)
-    return ErnieTEModel
+    return ErnieTEModel_
@@ -1066,7 +1066,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
 )
 
 
-def _seedance2_text_inputs():
+def _seedance2_text_inputs(resolutions: list[str]):
     return [
         IO.String.Input(
             "prompt",
@@ -1076,7 +1076,7 @@ def _seedance2_text_inputs():
         ),
         IO.Combo.Input(
             "resolution",
-            options=["480p", "720p"],
+            options=resolutions,
             tooltip="Resolution of the output video.",
         ),
         IO.Combo.Input(
@@ -1114,8 +1114,8 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
                 IO.DynamicCombo.Input(
                     "model",
                     options=[
-                        IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
-                        IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
+                        IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
+                        IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
                     ],
                     tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
                 ),
@@ -1152,11 +1152,14 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
                 (
                     $rate480 := 10044;
                     $rate720 := 21600;
+                    $rate1080 := 48800;
                     $m := widgets.model;
                     $pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
                     $res := $lookup(widgets, "model.resolution");
                     $dur := $lookup(widgets, "model.duration");
-                    $rate := $res = "720p" ? $rate720 : $rate480;
+                    $rate := $res = "1080p" ? $rate1080 :
+                             $res = "720p" ? $rate720 :
+                             $rate480;
                     $cost := $dur * $rate * $pricePer1K / 1000;
                     {"type": "usd", "usd": $cost, "format": {"approximate": true}}
                 )
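Sanity check on the badge arithmetic: cost = duration x per-resolution token rate x price-per-1K-tokens / 1000. For a hypothetical 5-second non-fast job:

rates = {"480p": 10044, "720p": 21600, "1080p": 48800}
price_per_1k = 0.01001                       # "fast" variants use 0.008008

def seedance2_cost(duration_s, resolution):
    return duration_s * rates[resolution] * price_per_1k / 1000

print(round(seedance2_cost(5, "720p"), 2))   # 1.08 USD
print(round(seedance2_cost(5, "1080p"), 2))  # 2.44 USD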
@@ -1195,6 +1198,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
             status_extractor=lambda r: r.status,
             price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
             poll_interval=9,
+            max_poll_attempts=180,
         )
         return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
 
@@ -1212,8 +1216,8 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
                 IO.DynamicCombo.Input(
                     "model",
                     options=[
-                        IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()),
-                        IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()),
+                        IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
+                        IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
                     ],
                     tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
                 ),
@@ -1259,11 +1263,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
                 (
                     $rate480 := 10044;
                     $rate720 := 21600;
+                    $rate1080 := 48800;
                     $m := widgets.model;
                     $pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
                     $res := $lookup(widgets, "model.resolution");
                     $dur := $lookup(widgets, "model.duration");
-                    $rate := $res = "720p" ? $rate720 : $rate480;
+                    $rate := $res = "1080p" ? $rate1080 :
+                             $res = "720p" ? $rate720 :
+                             $rate480;
                     $cost := $dur * $rate * $pricePer1K / 1000;
                     {"type": "usd", "usd": $cost, "format": {"approximate": true}}
                 )
@@ -1324,13 +1331,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
             status_extractor=lambda r: r.status,
             price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
             poll_interval=9,
+            max_poll_attempts=180,
         )
         return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
 
 
-def _seedance2_reference_inputs():
+def _seedance2_reference_inputs(resolutions: list[str]):
     return [
-        *_seedance2_text_inputs(),
+        *_seedance2_text_inputs(resolutions),
         IO.Autogrow.Input(
             "reference_images",
             template=IO.Autogrow.TemplateNames(
@@ -1382,8 +1390,8 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
                 IO.DynamicCombo.Input(
                     "model",
                     options=[
-                        IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs()),
-                        IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs()),
+                        IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
+                        IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
                     ],
                     tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
                 ),
@@ -1423,13 +1431,16 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
                 (
                     $rate480 := 10044;
                     $rate720 := 21600;
+                    $rate1080 := 48800;
                     $m := widgets.model;
                     $hasVideo := $lookup(inputGroups, "model.reference_videos") > 0;
                     $noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001;
                     $videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149;
                     $res := $lookup(widgets, "model.resolution");
                     $dur := $lookup(widgets, "model.duration");
-                    $rate := $res = "720p" ? $rate720 : $rate480;
+                    $rate := $res = "1080p" ? $rate1080 :
+                             $res = "720p" ? $rate720 :
+                             $rate480;
                     $noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000;
                     $minVideoFactor := $ceil($dur * 5 / 3);
                     $minVideoCost := $minVideoFactor * $rate * $videoPricePer1K / 1000;
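The reference-video branch prices differently: $minVideoFactor := $ceil($dur * 5 / 3) sets a minimum billable factor, multiplied by the cheaper with-video token price. For a hypothetical 6-second 720p non-fast request:

import math

rate_720 = 21600
video_price_per_1k = 0.006149            # non-fast, with video input

dur = 6
min_video_factor = math.ceil(dur * 5 / 3)                         # 10
min_video_cost = min_video_factor * rate_720 * video_price_per_1k / 1000
print(min_video_factor, round(min_video_cost, 2))                 # 10 1.33 (USD)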
@@ -1559,6 +1570,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
             status_extractor=lambda r: r.status,
             price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
             poll_interval=9,
+            max_poll_attempts=180,
         )
         return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
 
@@ -221,14 +221,17 @@ class TencentTextToModelNode(IO.ComfyNode):
             response_model=To3DProTaskResultResponse,
             status_extractor=lambda r: r.Status,
         )
-        obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
+        obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)
+        obj_result = None
+        if obj_file_response:
+            obj_result = await download_and_extract_obj_zip(obj_file_response.Url)
         return IO.NodeOutput(
             f"{task_id}.glb",
             await download_url_to_file_3d(
                 get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
             ),
-            obj_result.obj,
-            obj_result.texture,
+            obj_result.obj if obj_result else None,
+            obj_result.texture if obj_result else None,
        )
 
 
@@ -378,7 +381,9 @@ class TencentImageToModelNode(IO.ComfyNode):
             response_model=To3DProTaskResultResponse,
             status_extractor=lambda r: r.Status,
         )
-        obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
+        obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)
+        if obj_file_response:
+            obj_result = await download_and_extract_obj_zip(obj_file_response.Url)
         return IO.NodeOutput(
             f"{task_id}.glb",
             await download_url_to_file_3d(
@@ -390,6 +395,17 @@ class TencentImageToModelNode(IO.ComfyNode):
                 obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
                 obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
             )
+        return IO.NodeOutput(
+            f"{task_id}.glb",
+            await download_url_to_file_3d(
+                get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+            ),
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
 
 
 class TencentModelTo3DUVNode(IO.ComfyNode):
@@ -17,46 +17,12 @@ from comfy_api_nodes.util import (
 )
 from comfy_extras.nodes_images import SVG
 
+_ARROW_MODELS = ["arrow-1.1", "arrow-1.1-max", "arrow-preview"]
+
 
-class QuiverTextToSVGNode(IO.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return IO.Schema(
-            node_id="QuiverTextToSVGNode",
-            display_name="Quiver Text to SVG",
-            category="api node/image/Quiver",
-            description="Generate an SVG from a text prompt using Quiver AI.",
-            inputs=[
-                IO.String.Input(
-                    "prompt",
-                    multiline=True,
-                    default="",
-                    tooltip="Text description of the desired SVG output.",
-                ),
-                IO.String.Input(
-                    "instructions",
-                    multiline=True,
-                    default="",
-                    tooltip="Additional style or formatting guidance.",
-                    optional=True,
-                ),
-                IO.Autogrow.Input(
-                    "reference_images",
-                    template=IO.Autogrow.TemplatePrefix(
-                        IO.Image.Input("image"),
-                        prefix="ref_",
-                        min=0,
-                        max=4,
-                    ),
-                    tooltip="Up to 4 reference images to guide the generation.",
-                    optional=True,
-                ),
-                IO.DynamicCombo.Input(
-                    "model",
-                    options=[
-                        IO.DynamicCombo.Option(
-                            "arrow-preview",
-                            [
+def _arrow_sampling_inputs():
+    """Shared sampling inputs for all Arrow model variants."""
+    return [
         IO.Float.Input(
             "temperature",
             default=1.0,
@@ -87,9 +53,46 @@ class QuiverTextToSVGNode(IO.ComfyNode):
             tooltip="Token presence penalty.",
             advanced=True,
         ),
-    ],
+    ]
+
+
+class QuiverTextToSVGNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="QuiverTextToSVGNode",
+            display_name="Quiver Text to SVG",
+            category="api node/image/Quiver",
+            description="Generate an SVG from a text prompt using Quiver AI.",
+            inputs=[
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    default="",
+                    tooltip="Text description of the desired SVG output.",
                 ),
-            ],
+                IO.String.Input(
+                    "instructions",
+                    multiline=True,
+                    default="",
+                    tooltip="Additional style or formatting guidance.",
+                    optional=True,
+                    advanced=True,
+                ),
+                IO.Autogrow.Input(
+                    "reference_images",
+                    template=IO.Autogrow.TemplatePrefix(
+                        IO.Image.Input("image"),
+                        prefix="ref_",
+                        min=0,
+                        max=4,
+                    ),
+                    tooltip="Up to 4 reference images to guide the generation.",
+                    optional=True,
+                ),
+                IO.DynamicCombo.Input(
+                    "model",
+                    options=[IO.DynamicCombo.Option(m, _arrow_sampling_inputs()) for m in _ARROW_MODELS],
                     tooltip="Model to use for SVG generation.",
                 ),
                 IO.Int.Input(
@@ -112,7 +115,16 @@ class QuiverTextToSVGNode(IO.ComfyNode):
             ],
             is_api_node=True,
             price_badge=IO.PriceBadge(
-                expr="""{"type":"usd","usd":0.429}""",
+                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
+                expr="""
+                (
+                    $contains(widgets.model, "max")
+                        ? {"type":"usd","usd":0.3575}
+                    : $contains(widgets.model, "preview")
+                        ? {"type":"usd","usd":0.429}
+                    : {"type":"usd","usd":0.286}
+                )
+                """,
             ),
         )
 
@@ -176,12 +188,13 @@ class QuiverImageToSVGNode(IO.ComfyNode):
                     "auto_crop",
                     default=False,
                     tooltip="Automatically crop to the dominant subject.",
+                    advanced=True,
                 ),
                 IO.DynamicCombo.Input(
                     "model",
                     options=[
                         IO.DynamicCombo.Option(
-                            "arrow-preview",
+                            m,
                             [
                                 IO.Int.Input(
                                     "target_size",
@@ -189,39 +202,12 @@ class QuiverImageToSVGNode(IO.ComfyNode):
                                     min=128,
                                     max=4096,
                                     tooltip="Square resize target in pixels.",
-                                ),
-                                IO.Float.Input(
-                                    "temperature",
-                                    default=1.0,
-                                    min=0.0,
-                                    max=2.0,
-                                    step=0.1,
-                                    display_mode=IO.NumberDisplay.slider,
-                                    tooltip="Randomness control. Higher values increase randomness.",
-                                    advanced=True,
-                                ),
-                                IO.Float.Input(
-                                    "top_p",
-                                    default=1.0,
-                                    min=0.05,
-                                    max=1.0,
-                                    step=0.05,
-                                    display_mode=IO.NumberDisplay.slider,
-                                    tooltip="Nucleus sampling parameter.",
-                                    advanced=True,
-                                ),
-                                IO.Float.Input(
-                                    "presence_penalty",
-                                    default=0.0,
-                                    min=-2.0,
-                                    max=2.0,
-                                    step=0.1,
-                                    display_mode=IO.NumberDisplay.slider,
-                                    tooltip="Token presence penalty.",
                                     advanced=True,
                                 ),
+                                *_arrow_sampling_inputs(),
                             ],
-                        ),
+                        )
+                        for m in _ARROW_MODELS
                     ],
                     tooltip="Model to use for SVG vectorization.",
                 ),
@@ -245,7 +231,16 @@ class QuiverImageToSVGNode(IO.ComfyNode):
             ],
             is_api_node=True,
             price_badge=IO.PriceBadge(
-                expr="""{"type":"usd","usd":0.429}""",
+                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
+                expr="""
+                (
+                    $contains(widgets.model, "max")
+                        ? {"type":"usd","usd":0.3575}
+                    : $contains(widgets.model, "preview")
+                        ? {"type":"usd","usd":0.429}
+                    : {"type":"usd","usd":0.286}
+                )
+                """,
             ),
         )
 
@@ -401,7 +401,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
             ],
             is_api_node=True,
             price_badge=IO.PriceBadge(
-                expr="""{"type":"usd","usd":0.25}""",
+                expr="""{"type":"usd","usd":0.4}""",
             ),
         )
 
@@ -510,7 +510,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
             ],
             is_api_node=True,
             price_badge=IO.PriceBadge(
-                expr="""{"type":"usd","usd":0.25}""",
+                expr="""{"type":"usd","usd":0.6}""",
             ),
         )
 
@@ -593,7 +593,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
             ],
             is_api_node=True,
             price_badge=IO.PriceBadge(
-                expr="""{"type":"usd","usd":0.01}""",
+                expr="""{"type":"usd","usd":0.02}""",
             ),
         )
 
@ -3,136 +3,136 @@ from typing_extensions import override
|
|||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, IO
|
||||||
|
|
||||||
|
|
||||||
class TextEncodeAceStepAudio(io.ComfyNode):
|
class TextEncodeAceStepAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return IO.Schema(
|
||||||
node_id="TextEncodeAceStepAudio",
|
node_id="TextEncodeAceStepAudio",
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
IO.Clip.Input("clip"),
|
||||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||||
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
],
|
],
|
||||||
outputs=[io.Conditioning.Output()],
|
outputs=[IO.Conditioning.Output()],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
|
def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput:
|
||||||
tokens = clip.tokenize(tags, lyrics=lyrics)
|
tokens = clip.tokenize(tags, lyrics=lyrics)
|
||||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||||
return io.NodeOutput(conditioning)
|
return IO.NodeOutput(conditioning)
|
||||||
|
|
||||||
class TextEncodeAceStepAudio15(io.ComfyNode):
|
class TextEncodeAceStepAudio15(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return IO.Schema(
|
||||||
node_id="TextEncodeAceStepAudio1.5",
|
node_id="TextEncodeAceStepAudio1.5",
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
IO.Clip.Input("clip"),
|
||||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
||||||
io.Int.Input("bpm", default=120, min=10, max=300),
|
IO.Int.Input("bpm", default=120, min=10, max=300),
|
||||||
io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
|
IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
|
||||||
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||||
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||||
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||||
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||||
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||||
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||||
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||||
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||||
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
||||||
],
|
],
|
||||||
outputs=[io.Conditioning.Output()],
|
outputs=[IO.Conditioning.Output()],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
|
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput:
|
||||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
|
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
|
||||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||||
return io.NodeOutput(conditioning)
|
return IO.NodeOutput(conditioning)
|
||||||
|
|
||||||
|
|
||||||
class EmptyAceStepLatentAudio(io.ComfyNode):
|
class EmptyAceStepLatentAudio(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return IO.Schema(
|
||||||
node_id="EmptyAceStepLatentAudio",
|
node_id="EmptyAceStepLatentAudio",
|
||||||
display_name="Empty Ace Step 1.0 Latent Audio",
|
display_name="Empty Ace Step 1.0 Latent Audio",
|
||||||
category="latent/audio",
|
category="latent/audio",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||||
io.Int.Input(
|
IO.Int.Input(
|
||||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
                ),
            ],
-            outputs=[io.Latent.Output()],
+            outputs=[IO.Latent.Output()],
        )

    @classmethod
-    def execute(cls, seconds, batch_size) -> io.NodeOutput:
+    def execute(cls, seconds, batch_size) -> IO.NodeOutput:
        length = int(seconds * 44100 / 512 / 8)
        latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
-        return io.NodeOutput({"samples": latent, "type": "audio"})
+        return IO.NodeOutput({"samples": latent, "type": "audio"})


-class EmptyAceStep15LatentAudio(io.ComfyNode):
+class EmptyAceStep15LatentAudio(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
-        return io.Schema(
+        return IO.Schema(
            node_id="EmptyAceStep1.5LatentAudio",
            display_name="Empty Ace Step 1.5 Latent Audio",
            category="latent/audio",
            inputs=[
-                io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
-                io.Int.Input(
+                IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
+                IO.Int.Input(
                    "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
                ),
            ],
-            outputs=[io.Latent.Output()],
+            outputs=[IO.Latent.Output()],
        )

    @classmethod
-    def execute(cls, seconds, batch_size) -> io.NodeOutput:
+    def execute(cls, seconds, batch_size) -> IO.NodeOutput:
        length = round((seconds * 48000 / 1920))
        latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
-        return io.NodeOutput({"samples": latent, "type": "audio"})
+        return IO.NodeOutput({"samples": latent, "type": "audio"})


-class ReferenceAudio(io.ComfyNode):
+class ReferenceAudio(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
-        return io.Schema(
+        return IO.Schema(
            node_id="ReferenceTimbreAudio",
            display_name="Reference Audio",
            category="advanced/conditioning/audio",
            is_experimental=True,
            description="This node sets the reference audio for ace step 1.5",
            inputs=[
-                io.Conditioning.Input("conditioning"),
-                io.Latent.Input("latent", optional=True),
+                IO.Conditioning.Input("conditioning"),
+                IO.Latent.Input("latent", optional=True),
            ],
            outputs=[
-                io.Conditioning.Output(),
+                IO.Conditioning.Output(),
            ]
        )

    @classmethod
-    def execute(cls, conditioning, latent=None) -> io.NodeOutput:
+    def execute(cls, conditioning, latent=None) -> IO.NodeOutput:
        if latent is not None:
            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
-        return io.NodeOutput(conditioning)
+        return IO.NodeOutput(conditioning)


 class AceExtension(ComfyExtension):
    @override
-    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            TextEncodeAceStepAudio,
            EmptyAceStepLatentAudio,
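A note on the latent sizing above: both empty-latent nodes turn a duration in seconds into a latent frame count using the audio sample rate and the model's total temporal downsampling (44100 Hz with a combined factor of 512 * 8 for Ace Step, 48000 Hz with an effective hop of 1920 for Ace Step 1.5). A minimal sketch of the arithmetic; the frame counts follow directly from the code above, while reading 512 as a VAE hop length and 8 as an extra temporal compression factor is an assumption:

    # Latent length math lifted from the two execute() methods above.
    def ace_step_latent_length(seconds: float) -> int:
        # 44100 Hz audio, downsampled by 512 (assumed VAE hop) and by 8
        # (assumed extra temporal compression); truncated as in the node.
        return int(seconds * 44100 / 512 / 8)

    def ace_step_15_latent_length(seconds: float) -> int:
        # 48000 Hz audio with 1920 samples per latent frame.
        return round(seconds * 48000 / 1920)

    assert ace_step_latent_length(120.0) == 1291   # default 120 s clip
    assert ace_step_15_latent_length(120.0) == 3000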
@@ -7,7 +7,10 @@ import comfy.model_management
 import comfy.ldm.common_dit
 import comfy.latent_formats
 import comfy.ldm.lumina.controlnet
+import comfy.ldm.supir.supir_modules
 from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
+from comfy_api.latest import io
+from comfy.ldm.supir.supir_patch import SUPIRPatch


 class BlockWiseControlBlock(torch.nn.Module):
@@ -266,6 +269,27 @@ class ModelPatchLoader:
                out_dim=sd["audio_proj.norm.weight"].shape[0],
                device=comfy.model_management.unet_offload_device(),
                operations=comfy.ops.manual_cast)
+        elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
+            prefix_replace = {}
+            if 'model.control_model.input_hint_block.0.weight' in sd:
+                prefix_replace["model.control_model."] = "control_model."
+                prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
+            else:
+                prefix_replace["control_model."] = "control_model."
+                prefix_replace["project_modules."] = "project_modules."
+
+            # Extract denoise_encoder weights before filter_keys discards them
+            de_prefix = "first_stage_model.denoise_encoder."
+            denoise_encoder_sd = {}
+            for k in list(sd.keys()):
+                if k.startswith(de_prefix):
+                    denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k)
+
+            sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
+            sd.pop("control_model.mask_LQ", None)
+            model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+            if denoise_encoder_sd:
+                model.denoise_encoder_sd = denoise_encoder_sd

        model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
        model.load_state_dict(sd, assign=model_patcher.is_dynamic())
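The new SUPIR branch above normalizes checkpoint key prefixes before constructing the model. A standalone sketch of the remapping semantics it relies on, assuming comfy.utils.state_dict_prefix_replace with filter_keys=True keeps only keys that match a listed prefix and rewrites that prefix (the helper below is illustrative, not the library code):

    def prefix_replace_sketch(sd: dict, replace: dict, filter_keys: bool = True) -> dict:
        # Rewrite matching prefixes; with filter_keys=True, unmatched keys are dropped.
        out = {}
        for key, value in sd.items():
            for old, new in replace.items():
                if key.startswith(old):
                    out[new + key[len(old):]] = value
                    break
            else:
                if not filter_keys:
                    out[key] = value
        return out

    sd = {"model.control_model.input_hint_block.0.weight": 0, "model.other.weight": 1}
    remapped = prefix_replace_sketch(sd, {"model.control_model.": "control_model."})
    assert list(remapped) == ["control_model.input_hint_block.0.weight"]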
@@ -565,9 +589,89 @@ class MultiTalkModelPatch(torch.nn.Module):
        )


+class SUPIRApply(io.ComfyNode):
+    @classmethod
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="SUPIRApply",
+            category="model_patches/supir",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model"),
+                io.ModelPatch.Input("model_patch"),
+                io.Vae.Input("vae"),
+                io.Image.Input("image"),
+                io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
+                               tooltip="Control strength at the start of sampling (high sigma)."),
+                io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
+                               tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."),
+                io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True,
+                               tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."),
+                io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True,
+                               tooltip="Sigma threshold below which restore_cfg is disabled."),
+            ],
+            outputs=[io.Model.Output()],
+        )
+
+    @classmethod
+    def _encode_with_denoise_encoder(cls, vae, model_patch, image):
+        """Encode using denoise_encoder weights from SUPIR checkpoint if available."""
+        denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None)
+        if not denoise_sd:
+            return vae.encode(image)
+
+        # Clone VAE patcher, apply denoise_encoder weights to clone, encode
+        orig_patcher = vae.patcher
+        vae.patcher = orig_patcher.clone()
+        patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()}
+        vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0)
+        try:
+            return vae.encode(image)
+        finally:
+            vae.patcher = orig_patcher
+
+    @classmethod
+    def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type,
+                strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
+        model_patched = model.clone()
+        hint_latent = model.get_model_object("latent_format").process_in(
+            cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3]))
+        patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
+        patch.register(model_patched)
+
+        if restore_cfg > 0.0:
+            # Round-trip to match original pipeline: decode hint, re-encode with regular VAE
+            latent_format = model.get_model_object("latent_format")
+            decoded = vae.decode(latent_format.process_out(hint_latent))
+            x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3]))
+            sigma_max = 14.6146
+
+            def restore_cfg_function(args):
+                denoised = args["denoised"]
+                sigma = args["sigma"]
+                if sigma.dim() > 0:
+                    s = sigma[0].item()
+                else:
+                    s = sigma.item()
+                if s > restore_cfg_s_tmin:
+                    ref = x_center.to(device=denoised.device, dtype=denoised.dtype)
+                    b = denoised.shape[0]
+                    if ref.shape[0] != b:
+                        ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b]
+                    sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma
+                    d_center = denoised - ref
+                    denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg)
+                return denoised
+
+            model_patched.set_model_sampler_post_cfg_function(restore_cfg_function)
+
+        return io.NodeOutput(model_patched)
+
+
 NODE_CLASS_MAPPINGS = {
     "ModelPatchLoader": ModelPatchLoader,
     "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
     "ZImageFunControlnet": ZImageFunControlnet,
     "USOStyleReference": USOStyleReference,
+    "SUPIRApply": SUPIRApply,
 }
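The restore_cfg hook registered above blends each denoised prediction toward the re-encoded input latent with weight (sigma / sigma_max) ** restore_cfg, so the pull is strongest at high sigma and fades out by the end of sampling. A small worked example of that schedule, using the constants from the code (sigma_max = 14.6146, default restore_cfg = 4.0):

    sigma_max, restore_cfg = 14.6146, 4.0
    for sigma in (14.6146, 7.0, 1.0):
        print(sigma, (sigma / sigma_max) ** restore_cfg)
    # 14.6146 -> 1.0      denoised snaps fully to the reference latent
    # 7.0     -> ~0.053   weak pull by mid-sampling
    # 1.0     -> ~2.2e-5  negligible; below restore_cfg_s_tmin the hook skips entirely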
@@ -6,6 +6,7 @@ from PIL import Image
 import math
 from enum import Enum
 from typing import TypedDict, Literal
+import kornia

 import comfy.utils
 import comfy.model_management
@@ -660,6 +661,228 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
        return io.NodeOutput(batched)


+class ColorTransfer(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ColorTransfer",
+            category="image/postprocessing",
+            description="Match the colors of one image to another using various algorithms.",
+            search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
+            inputs=[
+                io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
+                io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
+                io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
+                io.DynamicCombo.Input("source_stats",
+                    tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
+                    options=[
+                        io.DynamicCombo.Option("per_frame", []),
+                        io.DynamicCombo.Option("uniform", []),
+                        io.DynamicCombo.Option("target_frame", [
+                            io.Int.Input("target_index", default=0, min=0, max=10000,
+                                tooltip="Frame index used as the source baseline for computing the transform to image_ref"),
+                        ]),
+                    ]),
+                io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
+            ],
+            outputs=[
+                io.Image.Output(display_name="image"),
+            ],
+        )
+
+    @staticmethod
+    def _to_lab(images, i, device):
+        return kornia.color.rgb_to_lab(
+            images[i:i+1].to(device, dtype=torch.float32).permute(0, 3, 1, 2))
+
+    @staticmethod
+    def _pool_stats(images, device, is_reinhard, eps):
+        """Two-pass pooled mean + std/cov across all frames."""
+        N, C = images.shape[0], images.shape[3]
+        HW = images.shape[1] * images.shape[2]
+        mean = torch.zeros(C, 1, device=device, dtype=torch.float32)
+        for i in range(N):
+            mean += ColorTransfer._to_lab(images, i, device).view(C, -1).mean(dim=-1, keepdim=True)
+        mean /= N
+        acc = torch.zeros(C, 1 if is_reinhard else C, device=device, dtype=torch.float32)
+        for i in range(N):
+            centered = ColorTransfer._to_lab(images, i, device).view(C, -1) - mean
+            if is_reinhard:
+                acc += (centered * centered).mean(dim=-1, keepdim=True)
+            else:
+                acc += centered @ centered.T / HW
+        if is_reinhard:
+            return mean, torch.sqrt(acc / N).clamp_min_(eps)
+        return mean, acc / N
+
+    @staticmethod
+    def _frame_stats(lab_flat, hw, is_reinhard, eps):
+        """Per-frame mean + std/cov."""
+        mean = lab_flat.mean(dim=-1, keepdim=True)
+        if is_reinhard:
+            return mean, lab_flat.std(dim=-1, keepdim=True, unbiased=False).clamp_min_(eps)
+        centered = lab_flat - mean
+        return mean, centered @ centered.T / hw
+
+    @staticmethod
+    def _mkl_matrix(cov_s, cov_r, eps):
+        """Compute MKL 3x3 transform matrix from source and ref covariances."""
+        eig_val_s, eig_vec_s = torch.linalg.eigh(cov_s)
+        sqrt_val_s = torch.sqrt(eig_val_s.clamp_min(0)).clamp_min_(eps)
+
+        scaled_V = eig_vec_s * sqrt_val_s.unsqueeze(0)
+        mid = scaled_V.T @ cov_r @ scaled_V
+        eig_val_m, eig_vec_m = torch.linalg.eigh(mid)
+        sqrt_m = torch.sqrt(eig_val_m.clamp_min(0))
+
+        inv_sqrt_s = 1.0 / sqrt_val_s
+        inv_scaled_V = eig_vec_s * inv_sqrt_s.unsqueeze(0)
+        M_half = (eig_vec_m * sqrt_m.unsqueeze(0)) @ eig_vec_m.T
+        return inv_scaled_V @ M_half @ inv_scaled_V.T
+
+    @staticmethod
+    def _histogram_lut(src, ref, bins=256):
+        """Build per-channel LUT from source and ref histograms. src/ref: (C, HW) in [0,1]."""
+        s_bins = (src * (bins - 1)).long().clamp(0, bins - 1)
+        r_bins = (ref * (bins - 1)).long().clamp(0, bins - 1)
+        s_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
+        r_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
+        ones_s = torch.ones_like(src)
+        ones_r = torch.ones_like(ref)
+        s_hist.scatter_add_(1, s_bins, ones_s)
+        r_hist.scatter_add_(1, r_bins, ones_r)
+        s_cdf = s_hist.cumsum(1)
+        s_cdf = s_cdf / s_cdf[:, -1:]
+        r_cdf = r_hist.cumsum(1)
+        r_cdf = r_cdf / r_cdf[:, -1:]
+        return torch.searchsorted(r_cdf, s_cdf).clamp_max_(bins - 1).float() / (bins - 1)
+
+    @classmethod
+    def _pooled_cdf(cls, images, device, num_bins=256):
+        """Build pooled CDF across all frames, one frame at a time."""
+        C = images.shape[3]
+        hist = torch.zeros(C, num_bins, device=device, dtype=torch.float32)
+        for i in range(images.shape[0]):
+            frame = images[i].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
+            bins = (frame * (num_bins - 1)).long().clamp(0, num_bins - 1)
+            hist.scatter_add_(1, bins, torch.ones_like(frame))
+        cdf = hist.cumsum(1)
+        return cdf / cdf[:, -1:]
+
+    @classmethod
+    def _build_histogram_transform(cls, image_target, image_ref, device, stats_mode, target_index, B):
+        """Build per-frame or uniform LUT transform for histogram mode."""
+        if stats_mode == 'per_frame':
+            return None  # LUT computed per-frame in the apply loop
+
+        r_cdf = cls._pooled_cdf(image_ref, device)
+        if stats_mode == 'target_frame':
+            ti = min(target_index, B - 1)
+            s_cdf = cls._pooled_cdf(image_target[ti:ti+1], device)
+        else:
+            s_cdf = cls._pooled_cdf(image_target, device)
+        return torch.searchsorted(r_cdf, s_cdf).clamp_max_(255).float() / 255.0
+
+    @classmethod
+    def _build_lab_transform(cls, image_target, image_ref, device, stats_mode, target_index, is_reinhard):
+        """Build transform parameters for Lab-based methods. Returns a transform function."""
+        eps = 1e-6
+        B, H, W, C = image_target.shape
+        B_ref = image_ref.shape[0]
+        single_ref = B_ref == 1
+        HW = H * W
+        HW_ref = image_ref.shape[1] * image_ref.shape[2]
+
+        # Precompute ref stats
+        if single_ref or stats_mode in ('uniform', 'target_frame'):
+            ref_mean, ref_sc = cls._pool_stats(image_ref, device, is_reinhard, eps)
+
+        # Uniform/target_frame: precompute single affine transform
+        if stats_mode in ('uniform', 'target_frame'):
+            if stats_mode == 'target_frame':
+                ti = min(target_index, B - 1)
+                s_lab = cls._to_lab(image_target, ti, device).view(C, -1)
+                s_mean, s_sc = cls._frame_stats(s_lab, HW, is_reinhard, eps)
+            else:
+                s_mean, s_sc = cls._pool_stats(image_target, device, is_reinhard, eps)
+
+            if is_reinhard:
+                scale = ref_sc / s_sc
+                offset = ref_mean - scale * s_mean
+                return lambda src_flat, **_: src_flat * scale + offset
+            T = cls._mkl_matrix(s_sc, ref_sc, eps)
+            offset = ref_mean - T @ s_mean
+            return lambda src_flat, **_: T @ src_flat + offset

+        # per_frame
+        def per_frame_transform(src_flat, frame_idx):
+            s_mean, s_sc = cls._frame_stats(src_flat, HW, is_reinhard, eps)
+
+            if single_ref:
+                r_mean, r_sc = ref_mean, ref_sc
+            else:
+                ri = min(frame_idx, B_ref - 1)
+                r_mean, r_sc = cls._frame_stats(cls._to_lab(image_ref, ri, device).view(C, -1), HW_ref, is_reinhard, eps)
+
+            centered = src_flat - s_mean
+            if is_reinhard:
+                return centered * (r_sc / s_sc) + r_mean
+            T = cls._mkl_matrix(centered @ centered.T / HW, r_sc, eps)
+            return T @ centered + r_mean
+
+        return per_frame_transform
+
+    @classmethod
+    def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput:
+        stats_mode = source_stats["source_stats"]
+        target_index = source_stats.get("target_index", 0)
+
+        if strength == 0 or image_ref is None:
+            return io.NodeOutput(image_target)
+
+        device = comfy.model_management.get_torch_device()
+        intermediate_device = comfy.model_management.intermediate_device()
+        intermediate_dtype = comfy.model_management.intermediate_dtype()
+
+        B, H, W, C = image_target.shape
+        B_ref = image_ref.shape[0]
+        pbar = comfy.utils.ProgressBar(B)
+        out = torch.empty(B, H, W, C, device=intermediate_device, dtype=intermediate_dtype)
+
+        if method == 'histogram':
+            uniform_lut = cls._build_histogram_transform(
+                image_target, image_ref, device, stats_mode, target_index, B)
+
+            for i in range(B):
+                src = image_target[i].to(device, dtype=torch.float32).permute(2, 0, 1)
+                src_flat = src.reshape(C, -1)
+                if uniform_lut is not None:
+                    lut = uniform_lut
+                else:
+                    ri = min(i, B_ref - 1)
+                    ref = image_ref[ri].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
+                    lut = cls._histogram_lut(src_flat, ref)
+                bin_idx = (src_flat * 255).long().clamp(0, 255)
+                matched = lut.gather(1, bin_idx).view(C, H, W)
+                result = matched if strength == 1.0 else torch.lerp(src, matched, strength)
+                out[i] = result.permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype)
+                pbar.update(1)
+        else:
+            transform = cls._build_lab_transform(image_target, image_ref, device, stats_mode, target_index, is_reinhard=method == "reinhard_lab")
+
+            for i in range(B):
+                src_frame = cls._to_lab(image_target, i, device)
+                corrected = transform(src_frame.view(C, -1), frame_idx=i)
+                if strength == 1.0:
+                    result = kornia.color.lab_to_rgb(corrected.view(1, C, H, W))
+                else:
+                    result = kornia.color.lab_to_rgb(torch.lerp(src_frame, corrected.view(1, C, H, W), strength))
+                out[i] = result.squeeze(0).permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype)
+                pbar.update(1)
+
+        return io.NodeOutput(out)
+
+
 class PostProcessingExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
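For reference, the two Lab-space methods implemented above reduce to the following math: reinhard_lab is a per-channel affine map matching mean and standard deviation, while mkl_lab applies the closed-form Monge-Kantorovich linear map T = Cs^(-1/2) (Cs^(1/2) Cr Cs^(1/2))^(1/2) Cs^(-1/2), the symmetric positive map carrying the source covariance Cs onto the reference covariance Cr. A minimal standalone sketch on flat (C, N) tensors, illustrative rather than the node code itself:

    import torch

    def reinhard(src, ref, eps=1e-6):
        # Per-channel affine map: match mean and std of src to ref.
        s_mean, s_std = src.mean(-1, keepdim=True), src.std(-1, keepdim=True).clamp_min(eps)
        r_mean, r_std = ref.mean(-1, keepdim=True), ref.std(-1, keepdim=True).clamp_min(eps)
        return (src - s_mean) * (r_std / s_std) + r_mean

    def mkl(src, ref, eps=1e-6):
        # Monge-Kantorovich linear transfer between covariances.
        def cov(x):
            c = x - x.mean(-1, keepdim=True)
            return c @ c.T / x.shape[-1]

        def sqrtm(m):
            vals, vecs = torch.linalg.eigh(m)
            return (vecs * vals.clamp_min(0).sqrt()) @ vecs.T

        cs_half = sqrtm(cov(src))
        cs_half_inv = torch.linalg.inv(cs_half + eps * torch.eye(src.shape[0]))
        T = cs_half_inv @ sqrtm(cs_half @ cov(ref) @ cs_half) @ cs_half_inv
        return T @ (src - src.mean(-1, keepdim=True)) + ref.mean(-1, keepdim=True)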
@@ -673,6 +896,7 @@ class PostProcessingExtension(ComfyExtension):
             BatchImagesNode,
             BatchMasksNode,
             BatchLatentsNode,
+            ColorTransfer,
             # BatchImagesMasksLatentsNode,
         ]

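The histogram method in the same node matches per-channel CDFs: torch.searchsorted maps each source intensity bin to the first reference bin whose cumulative mass reaches it. A single-channel sketch of that LUT construction, illustrative only:

    import torch

    def match_channel(src, ref, bins=256):
        # src, ref: 1-D float tensors in [0, 1].
        s_hist = torch.histc(src, bins=bins, min=0.0, max=1.0)
        r_hist = torch.histc(ref, bins=bins, min=0.0, max=1.0)
        s_cdf = s_hist.cumsum(0) / s_hist.sum()
        r_cdf = r_hist.cumsum(0) / r_hist.sum()
        lut = torch.searchsorted(r_cdf, s_cdf).clamp_max(bins - 1).float() / (bins - 1)
        return lut[(src * (bins - 1)).long().clamp(0, bins - 1)]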
@@ -1,4 +1,5 @@
 import re
+import json
 from typing_extensions import override

 from comfy_api.latest import ComfyExtension, io
@@ -375,6 +376,39 @@ class RegexReplace(io.ComfyNode):
        return io.NodeOutput(result)


+class JsonExtractString(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="JsonExtractString",
+            display_name="Extract String from JSON",
+            category="utils/string",
+            search_aliases=["json", "extract json", "parse json", "json value", "read json"],
+            inputs=[
+                io.String.Input("json_string", multiline=True),
+                io.String.Input("key", multiline=False),
+            ],
+            outputs=[
+                io.String.Output(),
+            ]
+        )
+
+    @classmethod
+    def execute(cls, json_string, key):
+        try:
+            data = json.loads(json_string)
+            if isinstance(data, dict) and key in data:
+                value = data[key]
+                if value is None:
+                    return io.NodeOutput("")
+
+                return io.NodeOutput(str(value))
+
+            return io.NodeOutput("")
+
+        except (json.JSONDecodeError, TypeError):
+            return io.NodeOutput("")
+
 class StringExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -390,6 +424,7 @@ class StringExtension(ComfyExtension):
            RegexMatch,
            RegexExtract,
            RegexReplace,
+            JsonExtractString,
        ]

 async def comfy_entrypoint() -> StringExtension:
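Behavior sketch for the new JsonExtractString node, mirroring execute() above: only top-level keys of a JSON object are looked up, non-string values are stringified, and nulls, missing keys, non-object JSON, and parse errors all yield an empty string:

    import json

    def extract(json_string, key):
        try:
            data = json.loads(json_string)
            if isinstance(data, dict) and key in data:
                return "" if data[key] is None else str(data[key])
            return ""
        except (json.JSONDecodeError, TypeError):
            return ""

    assert extract('{"prompt": "a cat", "steps": 20}', "prompt") == "a cat"
    assert extract('{"prompt": "a cat", "steps": 20}', "steps") == "20"  # stringified
    assert extract('{"prompt": null}', "prompt") == ""                   # null
    assert extract('[1, 2, 3]', "prompt") == ""                          # non-object
    assert extract('not json', "prompt") == ""                           # parse error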
@@ -37,6 +37,7 @@ class TextGenerate(io.ComfyNode):
                io.Int.Input("max_length", default=256, min=1, max=2048),
                io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
                io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
+                io.Boolean.Input("use_default_template", optional=True, default=True, tooltip="Use the built in system prompt/template if the model has one.", advanced=True),
            ],
            outputs=[
                io.String.Output(display_name="generated_text"),
@@ -44,9 +45,9 @@ class TextGenerate(io.ComfyNode):
        )

    @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False, use_default_template=True) -> io.NodeOutput:

-        tokens = clip.tokenize(prompt, image=image, video=video, audio=audio, skip_template=False, min_length=1, thinking=thinking)
+        tokens = clip.tokenize(prompt, image=image, video=video, audio=audio, skip_template=not use_default_template, min_length=1, thinking=thinking)

        # Get sampling parameters from dynamic combo
        do_sample = sampling_mode.get("sampling_mode") == "on"
@@ -163,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
        )

    @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False, use_default_template=True) -> io.NodeOutput:
        if image is None:
            formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
        else:
            formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
-        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, video=video, audio=audio, thinking=thinking)
+        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, video=video, audio=audio, thinking=thinking, use_default_template=use_default_template)


 class TextgenExtension(ComfyExtension):
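The new use_default_template flag above simply inverts into tokenize's skip_template argument; TextGenerateLTX2Prompt formats its own <start_of_turn> conversation, so exposing the flag lets callers bypass the built-in template when needed. A trivial sketch of the mapping:

    def tokenize_kwargs(use_default_template: bool) -> dict:
        # use_default_template=True keeps the model's built-in chat template;
        # False tokenizes the caller-formatted prompt verbatim.
        return {"skip_template": not use_default_template}

    assert tokenize_kwargs(True) == {"skip_template": False}
    assert tokenize_kwargs(False) == {"skip_template": True}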
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.19.0"
+__version__ = "0.19.3"

@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.19.0"
+version = "0.19.3"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.10"

@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.42.11
-comfyui-workflow-templates==0.9.50
+comfyui-frontend-package==1.42.12
+comfyui-workflow-templates==0.9.57
 comfyui-embedded-docs==0.4.3
 torch
 torchsde
@@ -19,7 +19,7 @@ scipy
 tqdm
 psutil
 alembic
-SQLAlchemy
+SQLAlchemy>=2.0
 filelock
 av>=14.2.0
 comfy-kitchen>=0.2.8