Merge branch 'comfyanonymous:master' into master

patientx 2025-12-03 16:15:09 +03:00 committed by GitHub
commit 894604b268
19 changed files with 958 additions and 266 deletions

View File

@@ -66,8 +66,10 @@ if branch is None:
     try:
         ref = repo.lookup_reference('refs/remotes/origin/master')
     except:
-        print("pulling.") # noqa: T201
-        pull(repo)
+        print("fetching.") # noqa: T201
+        for remote in repo.remotes:
+            if remote.name == "origin":
+                remote.fetch()
         ref = repo.lookup_reference('refs/remotes/origin/master')
     repo.checkout(ref)
     branch = repo.lookup_branch('master')

@@ -149,3 +151,4 @@ try:
     shutil.copy(stable_update_script, stable_update_script_to)
 except:
     pass
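For reference, the replacement fetch path in isolation — a minimal pygit2 sketch of the same calls the diff introduces (the repository path is an assumption):

import pygit2

repo = pygit2.Repository(".")  # assumed path
# fetch only the "origin" remote instead of pull(repo), then check out remote master
for remote in repo.remotes:
    if remote.name == "origin":
        remote.fetch()
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)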

View File

@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
+from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm

 import model_management, model_patcher

 class SRResidualCausalBlock3D(nn.Module):

View File

@@ -1,42 +1,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
 import comfy.ops
 import comfy.ldm.models.autoencoder
 import comfy.model_management

 ops = comfy.ops.disable_weight_init

-class NoPadConv3d(nn.Module):
-    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
-        super().__init__()
-        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
-
-    def forward(self, x):
-        return self.conv(x)
-
-def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
-    x = xl[0]
-    xl.clear()
-    if conv_carry_out is not None:
-        to_push = x[:, :, -2:, :, :].clone()
-        conv_carry_out.append(to_push)
-    if isinstance(op, NoPadConv3d):
-        if conv_carry_in is None:
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
-        else:
-            carry_len = conv_carry_in[0].shape[2]
-            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
-    out = op(x)
-    return out
-
 class RMS_norm(nn.Module):
     def __init__(self, dim):
@@ -49,7 +19,7 @@ class RMS_norm(nn.Module):
         return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)

 class DnSmpl(nn.Module):
-    def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
+    def __init__(self, ic, oc, tds, refiner_vae, op):
         super().__init__()
         fct = 2 * 2 * 2 if tds else 1 * 2 * 2
         assert oc % fct == 0
@@ -109,7 +79,7 @@ class DnSmpl(nn.Module):

 class UpSmpl(nn.Module):
-    def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
+    def __init__(self, ic, oc, tus, refiner_vae, op):
         super().__init__()
         fct = 2 * 2 * 2 if tus else 1 * 2 * 2
         self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@@ -163,23 +133,6 @@ class UpSmpl(nn.Module):
         return h + x

-class HunyuanRefinerResnetBlock(ResnetBlock):
-    def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
-        super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
-
-    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
-        h = x
-        h = [ self.swish(self.norm1(x)) ]
-        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-        h = [ self.dropout(self.swish(self.norm2(h))) ]
-        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-        if self.in_channels != self.out_channels:
-            x = self.nin_shortcut(x)
-        return x+h
-
 class Encoder(nn.Module):
     def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                  ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@@ -191,7 +144,7 @@ class Encoder(nn.Module):
         self.refiner_vae = refiner_vae

         if self.refiner_vae:
-            conv_op = NoPadConv3d
+            conv_op = CarriedConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -206,9 +159,10 @@
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
-                                                                   out_channels=tgt,
-                                                                   conv_op=conv_op, norm_op=norm_op)
-                                         for j in range(num_res_blocks)])
+            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+                                                     out_channels=tgt,
+                                                     temb_channels=0,
+                                                     conv_op=conv_op, norm_op=norm_op)
+                                         for j in range(num_res_blocks)])
             ch = tgt
             if i < depth:
@@ -218,9 +172,9 @@
             self.down.append(stage)

         self.mid = nn.Module()
-        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

         self.norm_out = norm_op(ch)
         self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -246,22 +200,20 @@
             conv_carry_out = []
             if i == len(x) - 1:
                 conv_carry_out = None
             x1 = [ x1 ]
             x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
             for stage in self.down:
                 for blk in stage.block:
-                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                 if hasattr(stage, 'downsample'):
                     x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
             out.append(x1)
             conv_carry_in = conv_carry_out

-        if len(out) > 1:
-            out = torch.cat(out, dim=2)
-        else:
-            out = out[0]
+        out = torch_cat_if_needed(out, dim=2)
         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
         del out
@@ -288,7 +240,7 @@ class Decoder(nn.Module):
         self.refiner_vae = refiner_vae

         if self.refiner_vae:
-            conv_op = NoPadConv3d
+            conv_op = CarriedConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -298,9 +250,9 @@
         self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)

         self.mid = nn.Module()
-        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

         self.up = nn.ModuleList()
         depth = (ffactor_spatial >> 1).bit_length()
@@ -308,9 +260,10 @@
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
-                                                                   out_channels=tgt,
-                                                                   conv_op=conv_op, norm_op=norm_op)
-                                         for j in range(num_res_blocks + 1)])
+            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+                                                     out_channels=tgt,
+                                                     temb_channels=0,
+                                                     conv_op=conv_op, norm_op=norm_op)
+                                         for j in range(num_res_blocks + 1)])
             ch = tgt
             if i < depth:
@@ -340,7 +293,7 @@
                 conv_carry_out = None
             for stage in self.up:
                 for blk in stage.block:
-                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                 if hasattr(stage, 'upsample'):
                     x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
@@ -350,10 +303,7 @@
             conv_carry_in = conv_carry_out
             del x

-        if len(out) > 1:
-            out = torch.cat(out, dim=2)
-        else:
-            out = out[0]
+        out = torch_cat_if_needed(out, dim=2)

         if not self.refiner_vae:
             if z.shape[-3] == 1:

View File

@@ -0,0 +1,113 @@
import torch
from torch import nn

from .model import JointTransformerBlock


class ZImageControlTransformerBlock(JointTransformerBlock):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        block_id=0,
        operation_settings=None,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            c = self.before_proj(c) + x
        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)
        return c_skip, c


class ZImage_Control(torch.nn.Module):
    def __init__(
        self,
        dim: int = 3840,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        multiple_of: int = 256,
        ffn_dim_multiplier: float = (8.0 / 3.0),
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        super().__init__()
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
        self.additional_in_dim = 0
        self.control_in_dim = 16
        n_refiner_layers = 2
        self.n_control_layers = 6
        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(
                    i,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    block_id=i,
                    operation_settings=operation_settings,
                )
                for i in range(self.n_control_layers)
            ]
        )

        all_x_embedder = {}
        patch_size = 2
        f_patch_size = 1
        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
        all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
        self.control_noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    z_image_modulation=True,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

    def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
        patch_size = 2
        f_patch_size = 1
        pH = pW = patch_size
        B, C, H, W = control_context.shape
        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))

        x_attn_mask = None
        for layer in self.control_noise_refiner:
            control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)

        return control_context

    def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
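The patch embedding in forward() packs 2x2 spatial patches into tokens. A standalone sketch of that reshape with made-up sizes, confirming the token dim matches the x_embedder input dim f_patch_size * patch_size * patch_size * control_in_dim = 1 * 2 * 2 * 16 = 64:

import torch

B, C, H, W = 1, 16, 8, 8
pH = pW = 2  # patch_size = 2, f_patch_size = 1
x = torch.randn(B, C, H, W)
# (B, C, H, W) -> (B, (H/pH)*(W/pW), pH*pW*C)
tokens = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
assert tokens.shape == (B, (H // pH) * (W // pW), pH * pW * C)  # (1, 16, 64)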

View File

@@ -568,7 +568,7 @@ class NextDiT(nn.Module):
         ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)

     # def forward(self, x, t, cap_feats, cap_mask):
-    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
         t = 1.0 - timesteps
         cap_feats = context
         cap_mask = attention_mask
@@ -585,16 +585,24 @@
         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

+        patches = transformer_options.get("patches", {})
         transformer_options = kwargs.get("transformer_options", {})
         x_is_tensor = isinstance(x, torch.Tensor)
-        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
-        freqs_cis = freqs_cis.to(x.device)
+        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+        freqs_cis = freqs_cis.to(img.device)

-        for layer in self.layers:
-            x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+        for i, layer in enumerate(self.layers):
+            img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+            if "double_block" in patches:
+                for p in patches["double_block"]:
+                    out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+                    if "img" in out:
+                        img[:, cap_size[0]:] = out["img"]
+                    if "txt" in out:
+                        img[:, :cap_size[0]] = out["txt"]

-        x = self.final_layer(x, adaln_input)
-        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
+        img = self.final_layer(img, adaln_input)
+        img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]

-        return -x
+        return -img

View File

@@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops

+def torch_cat_if_needed(xl, dim):
+    if len(xl) > 1:
+        return torch.cat(xl, dim)
+    else:
+        return xl[0]
+
 def get_timestep_embedding(timesteps, embedding_dim):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
     return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

+class CarriedConv3d(nn.Module):
+    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
+        super().__init__()
+        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+    def forward(self, x):
+        return self.conv(x)
+
+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
+    x = xl[0]
+    xl.clear()
+    if isinstance(op, CarriedConv3d):
+        if conv_carry_in is None:
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
+        else:
+            carry_len = conv_carry_in[0].shape[2]
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
+            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+        if conv_carry_out is not None:
+            to_push = x[:, :, -2:, :, :].clone()
+            conv_carry_out.append(to_push)
+    out = op(x)
+    return out
+
 class VideoConv3d(nn.Module):
     def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
         super().__init__()
@@ -89,29 +126,24 @@ class Upsample(nn.Module):
                                 stride=1,
                                 padding=1)

-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         scale_factor = self.scale_factor
         if isinstance(scale_factor, (int, float)):
             scale_factor = (scale_factor,) * (x.ndim - 2)

         if x.ndim == 5 and scale_factor[0] > 1.0:
-            t = x.shape[2]
-            if t > 1:
-                a, b = x.split((1, t - 1), dim=2)
-                del x
-                b = interpolate_up(b, scale_factor)
-            else:
-                a = x
-
-            a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
-            if t > 1:
-                x = torch.cat((a, b), dim=2)
-            else:
-                x = a
+            results = []
+            if conv_carry_in is None:
+                first = x[:, :, :1, :, :]
+                results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
+                x = x[:, :, 1:, :, :]
+            if x.shape[2] > 0:
+                results.append(interpolate_up(x, scale_factor))
+            x = torch_cat_if_needed(results, dim=2)
         else:
             x = interpolate_up(x, scale_factor)
         if self.with_conv:
-            x = self.conv(x)
+            x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
         return x
@@ -127,17 +159,20 @@ class Downsample(nn.Module):
                                 stride=stride,
                                 padding=0)

-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         if self.with_conv:
-            if x.ndim == 4:
+            if isinstance(self.conv, CarriedConv3d):
+                x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+            elif x.ndim == 4:
                 pad = (0, 1, 0, 1)
                 mode = "constant"
                 x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
+                x = self.conv(x)
             elif x.ndim == 5:
                 pad = (1, 1, 1, 1, 2, 0)
                 mode = "replicate"
                 x = torch.nn.functional.pad(x, pad, mode=mode)
-            x = self.conv(x)
+                x = self.conv(x)
         else:
             x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
         return x
@@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
                                         stride=1,
                                         padding=0)

-    def forward(self, x, temb=None):
+    def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
         h = x
         h = self.norm1(h)
-        h = self.swish(h)
-        h = self.conv1(h)
+        h = [ self.swish(h) ]
+        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)

         if temb is not None:
             h = h + self.temb_proj(self.swish(temb))[:,:,None,None]

         h = self.norm2(h)
         h = self.swish(h)
-        h = self.dropout(h)
-        h = self.conv2(h)
+        h = [ self.dropout(h) ]
+        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)

         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
+                x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
             else:
                 x = self.nin_shortcut(x)
@@ -520,9 +555,14 @@ class Encoder(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.resolution = resolution
         self.in_channels = in_channels
+        self.carried = False
         if conv3d:
-            conv_op = VideoConv3d
+            if not attn_resolutions:
+                conv_op = CarriedConv3d
+                self.carried = True
+            else:
+                conv_op = VideoConv3d
             mid_attn_conv_op = ops.Conv3d
         else:
             conv_op = ops.Conv2d
@@ -535,6 +575,7 @@
                                        stride=1,
                                        padding=1)

+        self.time_compress = 1
         curr_res = resolution
         in_ch_mult = (1,)+tuple(ch_mult)
         self.in_ch_mult = in_ch_mult
@@ -561,10 +602,15 @@
                 if time_compress is not None:
                     if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
                         stride = (1, 2, 2)
+                    else:
+                        self.time_compress *= 2
                 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
                 curr_res = curr_res // 2
             self.down.append(down)

+        if time_compress is not None:
+            self.time_compress = time_compress
+
         # middle
         self.mid = nn.Module()
         self.mid.block_1 = ResnetBlock(in_channels=block_in,
@@ -590,15 +636,42 @@
     def forward(self, x):
         # timestep embedding
         temb = None
-        # downsampling
-        h = self.conv_in(x)
-        for i_level in range(self.num_resolutions):
-            for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](h, temb)
-                if len(self.down[i_level].attn) > 0:
-                    h = self.down[i_level].attn[i_block](h)
-            if i_level != self.num_resolutions-1:
-                h = self.down[i_level].downsample(h)
+        if self.carried:
+            xl = [x[:, :, :1, :, :]]
+            if x.shape[2] > self.time_compress:
+                tc = self.time_compress
+                xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
+            x = xl
+        else:
+            x = [x]
+
+        out = []
+        conv_carry_in = None
+        for i, x1 in enumerate(x):
+            conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
+            # downsampling
+            x1 = [ x1 ]
+            h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+            for i_level in range(self.num_resolutions):
+                for i_block in range(self.num_res_blocks):
+                    h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
+                    if len(self.down[i_level].attn) > 0:
+                        assert i == 0 #carried should not happen if attn exists
+                        h1 = self.down[i_level].attn[i_block](h1)
+                if i_level != self.num_resolutions-1:
+                    h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
+            out.append(h1)
+            conv_carry_in = conv_carry_out
+
+        h = torch_cat_if_needed(out, dim=2)
+        del out

         # middle
         h = self.mid.block_1(h, temb)
@@ -607,15 +680,15 @@
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
+        h = [ nonlinearity(h) ]
+        h = conv_carry_causal_3d(h, self.conv_out)
         return h

 class Decoder(nn.Module):
     def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
-                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+                 resolution, z_channels, tanh_out=False, use_linear_attn=False,
                  conv_out_op=ops.Conv2d,
                  resnet_op=ResnetBlock,
                  attn_op=AttnBlock,
@@ -629,12 +702,18 @@
         self.num_res_blocks = num_res_blocks
         self.resolution = resolution
         self.in_channels = in_channels
-        self.give_pre_end = give_pre_end
         self.tanh_out = tanh_out
+        self.carried = False
         if conv3d:
-            conv_op = VideoConv3d
-            conv_out_op = VideoConv3d
+            if not attn_resolutions and resnet_op == ResnetBlock:
+                conv_op = CarriedConv3d
+                conv_out_op = CarriedConv3d
+                self.carried = True
+            else:
+                conv_op = VideoConv3d
+                conv_out_op = VideoConv3d
             mid_attn_conv_op = ops.Conv3d
         else:
             conv_op = ops.Conv2d
@@ -709,29 +788,43 @@
         temb = None

         # z to block_in
-        h = self.conv_in(z)
+        h = conv_carry_causal_3d([z], self.conv_in)

         # middle
         h = self.mid.block_1(h, temb, **kwargs)
         h = self.mid.attn_1(h, **kwargs)
         h = self.mid.block_2(h, temb, **kwargs)

+        if self.carried:
+            h = torch.split(h, 2, dim=2)
+        else:
+            h = [ h ]
+
+        out = []
+        conv_carry_in = None
         # upsampling
-        for i_level in reversed(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks+1):
-                h = self.up[i_level].block[i_block](h, temb, **kwargs)
-                if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h, **kwargs)
-            if i_level != 0:
-                h = self.up[i_level].upsample(h)
+        for i, h1 in enumerate(h):
+            conv_carry_out = []
+            if i == len(h) - 1:
+                conv_carry_out = None
+            for i_level in reversed(range(self.num_resolutions)):
+                for i_block in range(self.num_res_blocks+1):
+                    h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
+                    if len(self.up[i_level].attn) > 0:
+                        assert i == 0 #carried should not happen if attn exists
+                        h1 = self.up[i_level].attn[i_block](h1, **kwargs)
+                if i_level != 0:
+                    h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)

-        # end
-        if self.give_pre_end:
-            return h
+            h1 = self.norm_out(h1)
+            h1 = [ nonlinearity(h1) ]
+            h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
+            if self.tanh_out:
+                h1 = torch.tanh(h1)
+            out.append(h1)
+            conv_carry_in = conv_carry_out

-        h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h, **kwargs)
-        if self.tanh_out:
-            h = torch.tanh(h)
-        return h
+        out = torch_cat_if_needed(out, dim=2)
+
+        return out
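The carry mechanism above lets a causal conv run over a long video in temporal chunks while reproducing the full-sequence result: each chunk hands its last two already-padded frames to the next. A minimal self-contained sketch of the same idea, simplified to a 1D causal conv over time (kernel 3, stride 1, and replicate padding are assumptions matching the 3D code):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
conv = torch.nn.Conv1d(1, 1, kernel_size=3)

def full(x):
    # causal: pad 2 frames on the left, none on the right
    return conv(F.pad(x, (2, 0), mode='replicate'))

def chunked(x, chunk=4):
    outs, carry = [], None
    for t0 in range(0, x.shape[-1], chunk):
        piece = x[..., t0:t0 + chunk]
        if carry is None:
            piece = F.pad(piece, (2, 0), mode='replicate')
        else:
            piece = torch.cat([carry, piece], dim=-1)
        carry = piece[..., -2:]  # last 2 (already padded) frames feed the next chunk
        outs.append(conv(piece))
    return torch.cat(outs, dim=-1)

x = torch.randn(1, 1, 12)
assert torch.allclose(full(x), chunked(x), atol=1e-6)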

View File

@@ -704,7 +704,7 @@ class ModelPatcher:
                 lowvram_weight = False

-                potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+                potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem))
                 lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
                 weight_key = "{}.weight".format(n)
@@ -883,7 +883,7 @@
                 break
             module_offload_mem, module_mem, n, m, params = unload

-            potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+            potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)
             lowvram_possible = hasattr(m, "comfy_cast_weights")
             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
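For intuition, a worked example of the revised offload estimate with made-up numbers (NUM_STREAMS = 2, a 100 MB module with 40 MB of offloadable weights):

NUM_STREAMS = 2
module_mem, module_offload_mem = 100, 40  # MB, made-up
old_estimate = module_offload_mem * (NUM_STREAMS + 1)         # 120 MB
new_estimate = module_offload_mem + NUM_STREAMS * module_mem  # 240 MB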

View File

@@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if s.bias is not None:
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)

-    if bias_has_function:
-        with wf_context:
-            for f in s.bias_function:
-                bias = f(bias)
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    bias_a = bias
+    weight_a = weight
+
+    if s.bias is not None:
+        for f in s.bias_function:
+            bias = f(bias)

     if weight_has_function or weight.dtype != dtype:
-        with wf_context:
-            weight = weight.to(dtype=dtype)
-            if isinstance(weight, QuantizedTensor):
-                weight = weight.dequantize()
-            for f in s.weight_function:
-                weight = f(weight)
-
-    comfy.model_management.sync_stream(device, offload_stream)
+        weight = weight.to(dtype=dtype)
+        if isinstance(weight, QuantizedTensor):
+            weight = weight.dequantize()
+        for f in s.weight_function:
+            weight = f(weight)

     if offloadable:
-        return weight, bias, offload_stream
+        return weight, bias, (offload_stream, weight_a, bias_a)
     else:
         #Legacy function signature
         return weight, bias
@@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
 def uncast_bias_weight(s, weight, bias, offload_stream):
     if offload_stream is None:
         return
-    if weight is not None:
-        device = weight.device
+    os, weight_a, bias_a = offload_stream
+    if os is None:
+        return
+    if weight_a is not None:
+        device = weight_a.device
     else:
-        if bias is None:
+        if bias_a is None:
             return
-        device = bias.device
-    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        device = bias_a.device
+    os.wait_stream(comfy.model_management.current_stream(device))

 class CastWeightBiasOp:
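A hedged usage sketch of the new contract: the third value returned by cast_bias_weight() is now an opaque (offload_stream, weight_a, bias_a) handle rather than a bare stream, so uncast_bias_weight() can derive the device from the pre-weight_function tensors. The call site below and the offloadable flag are assumptions based on the truncated signature above:

weight, bias, handle = cast_bias_weight(layer, input=x, offloadable=True)  # layer/x assumed
out = torch.nn.functional.linear(x, weight, bias)
uncast_bias_weight(layer, weight, bias, handle)  # waits on the offload stream carried in the handle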

View File

@@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
 from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
 from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
 from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
-from . import _io as io
-from . import _ui as ui
+from . import _io_public as io
+from . import _ui_public as ui
 # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
 from comfy_execution.utils import get_executing_context
 from comfy_execution.progress import get_progress_state, PreviewImageTuple

View File

@@ -4,6 +4,7 @@ import copy
 import inspect
 from abc import ABC, abstractmethod
 from collections import Counter
+from collections.abc import Iterable
 from dataclasses import asdict, dataclass
 from enum import Enum
 from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
@@ -150,6 +151,9 @@ class _IO_V3:
     def __init__(self):
         pass

+    def validate(self):
+        pass
+
     @property
     def io_type(self):
         return self.Parent.io_type
@@ -182,6 +186,9 @@ class Input(_IO_V3):
     def get_io_type(self):
         return _StringIOType(self.io_type)

+    def get_all(self) -> list[Input]:
+        return [self]
+
 class WidgetInput(Input):
     '''
     Base class for a V3 Input with widget.
@@ -814,13 +821,61 @@ class MultiType:
         else:
             return super().as_dict()

+@comfytype(io_type="COMFY_MATCHTYPE_V3")
+class MatchType(ComfyTypeIO):
+    class Template:
+        def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
+            self.template_id = template_id
+            # account for syntactic sugar
+            if not isinstance(allowed_types, Iterable):
+                allowed_types = [allowed_types]
+            for t in allowed_types:
+                if not isinstance(t, type):
+                    if not isinstance(t, _ComfyType):
+                        raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
+                else:
+                    if not issubclass(t, _ComfyType):
+                        raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
+            self.allowed_types = allowed_types
+
+        def as_dict(self):
+            return {
+                "template_id": self.template_id,
+                "allowed_types": ",".join([t.io_type for t in self.allowed_types]),
+            }
+
+    class Input(Input):
+        def __init__(self, id: str, template: MatchType.Template,
+                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+            self.template = template
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "template": self.template.as_dict(),
+            })
+
+    class Output(Output):
+        def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
+                     is_output_list=False):
+            super().__init__(id, display_name, tooltip, is_output_list)
+            self.template = template
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "template": self.template.as_dict(),
+            })
+
 class DynamicInput(Input, ABC):
     '''
     Abstract class for dynamic input registration.
     '''
-    @abstractmethod
     def get_dynamic(self) -> list[Input]:
-        ...
+        return []
+
+    def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+        pass

 class DynamicOutput(Output, ABC):
     '''
@@ -830,99 +885,223 @@ class DynamicOutput(Output, ABC):
                  is_output_list=False):
         super().__init__(id, display_name, tooltip, is_output_list)

-    @abstractmethod
     def get_dynamic(self) -> list[Output]:
-        ...
+        return []
 @comfytype(io_type="COMFY_AUTOGROW_V3")
-class AutogrowDynamic(ComfyTypeI):
-    Type = list[Any]
-    class Input(DynamicInput):
-        def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
-                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
-            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
-            self.template_input = template_input
-            if min is not None:
-                assert(min >= 1)
-            if max is not None:
-                assert(max >= 1)
-            self.min = min
-            self.max = max
-
-        def get_dynamic(self) -> list[Input]:
-            curr_count = 1
-            new_inputs = []
-            for i in range(self.min):
-                new_input = copy.copy(self.template_input)
-                new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
-                if new_input.display_name is not None:
-                    new_input.display_name = f"{new_input.display_name}{curr_count}"
-                new_input.optional = self.optional or new_input.optional
-                if isinstance(self.template_input, WidgetInput):
-                    new_input.force_input = True
-                new_inputs.append(new_input)
-                curr_count += 1
-            # pretend to expand up to max
-            for i in range(curr_count-1, self.max):
-                new_input = copy.copy(self.template_input)
-                new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
-                if new_input.display_name is not None:
-                    new_input.display_name = f"{new_input.display_name}{curr_count}"
-                new_input.optional = True
-                if isinstance(self.template_input, WidgetInput):
-                    new_input.force_input = True
-                new_inputs.append(new_input)
-                curr_count += 1
-            return new_inputs
-
-@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
-class ComboDynamic(ComfyTypeI):
-    class Input(DynamicInput):
-        def __init__(self, id: str):
-            pass
-
-@comfytype(io_type="COMFY_MATCHTYPE_V3")
-class MatchType(ComfyTypeIO):
-    class Template:
-        def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
-            self.template_id = template_id
-            self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
-
-        def as_dict(self):
-            return {
-                "template_id": self.template_id,
-                "allowed_types": "".join(t.io_type for t in self.allowed_types),
-            }
-
-    class Input(DynamicInput):
-        def __init__(self, id: str, template: MatchType.Template,
-                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
-            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
-            self.template = template
-
-        def get_dynamic(self) -> list[Input]:
-            return [self]
-
-        def as_dict(self):
-            return super().as_dict() | prune_dict({
-                "template": self.template.as_dict(),
-            })
-
-    class Output(DynamicOutput):
-        def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
-                     is_output_list=False):
-            super().__init__(id, display_name, tooltip, is_output_list)
-            self.template = template
-
-        def get_dynamic(self) -> list[Output]:
-            return [self]
-
-        def as_dict(self):
-            return super().as_dict() | prune_dict({
-                "template": self.template.as_dict(),
-            })
+class Autogrow(ComfyTypeI):
+    Type = dict[str, Any]
+    _MaxNames = 100 # NOTE: max 100 names for sanity
+
+    class _AutogrowTemplate:
+        def __init__(self, input: Input):
+            # dynamic inputs are not allowed as the template input
+            assert(not isinstance(input, DynamicInput))
+            self.input = copy.copy(input)
+            if isinstance(self.input, WidgetInput):
+                self.input.force_input = True
+            self.names: list[str] = []
+            self.cached_inputs = {}
+
+        def _create_input(self, input: Input, name: str):
+            new_input = copy.copy(self.input)
+            new_input.id = name
+            return new_input
+
+        def _create_cached_inputs(self):
+            for name in self.names:
+                self.cached_inputs[name] = self._create_input(self.input, name)
+
+        def get_all(self) -> list[Input]:
+            return list(self.cached_inputs.values())
+
+        def as_dict(self):
+            return prune_dict({
+                "input": create_input_dict_v1([self.input]),
+            })
+
+        def validate(self):
+            self.input.validate()
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            real_inputs = []
+            for name, input in self.cached_inputs.items():
+                if name in live_inputs:
+                    real_inputs.append(input)
+            add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
+            add_dynamic_id_mapping(d, real_inputs, curr_prefix)
+
+    class TemplatePrefix(_AutogrowTemplate):
+        def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
+            super().__init__(input)
+            self.prefix = prefix
+            assert(min >= 0)
+            assert(max >= 1)
+            assert(max <= Autogrow._MaxNames)
+            self.min = min
+            self.max = max
+            self.names = [f"{self.prefix}{i}" for i in range(self.max)]
+            self._create_cached_inputs()
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "prefix": self.prefix,
+                "min": self.min,
+                "max": self.max,
+            })
+
+    class TemplateNames(_AutogrowTemplate):
+        def __init__(self, input: Input, names: list[str], min: int=1):
+            super().__init__(input)
+            self.names = names[:Autogrow._MaxNames]
+            assert(min >= 0)
+            self.min = min
+            self._create_cached_inputs()
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "names": self.names,
+                "min": self.min,
+            })
+
+    class Input(DynamicInput):
+        def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
+                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+            self.template = template
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "template": self.template.as_dict(),
+            })
+
+        def get_dynamic(self) -> list[Input]:
+            return self.template.get_all()
+
+        def get_all(self) -> list[Input]:
+            return [self] + self.template.get_all()
+
+        def validate(self):
+            self.template.validate()
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            curr_prefix = f"{curr_prefix}{self.id}."
+            # need to remove self from expected inputs dictionary; replaced by template inputs in frontend
+            for inner_dict in d.values():
+                if self.id in inner_dict:
+                    del inner_dict[self.id]
+            self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+
+@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
+class DynamicCombo(ComfyTypeI):
+    Type = dict[str, Any]
+
+    class Option:
+        def __init__(self, key: str, inputs: list[Input]):
+            self.key = key
+            self.inputs = inputs
+
+        def as_dict(self):
+            return {
+                "key": self.key,
+                "inputs": create_input_dict_v1(self.inputs),
+            }
+
+    class Input(DynamicInput):
+        def __init__(self, id: str, options: list[DynamicCombo.Option],
+                     display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+            self.options = options
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            # check if dynamic input's id is in live_inputs
+            if self.id in live_inputs:
+                curr_prefix = f"{curr_prefix}{self.id}."
+                key = live_inputs[self.id]
+                selected_option = None
+                for option in self.options:
+                    if option.key == key:
+                        selected_option = option
+                        break
+                if selected_option is not None:
+                    add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
+                    add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
+
+        def get_dynamic(self) -> list[Input]:
+            return [input for option in self.options for input in option.inputs]
+
+        def get_all(self) -> list[Input]:
+            return [self] + [input for option in self.options for input in option.inputs]
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "options": [o.as_dict() for o in self.options],
+            })
+
+        def validate(self):
+            # make sure all nested inputs are validated
+            for option in self.options:
+                for input in option.inputs:
+                    input.validate()
+
+@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
+class DynamicSlot(ComfyTypeI):
+    Type = dict[str, Any]
+
+    class Input(DynamicInput):
+        def __init__(self, slot: Input, inputs: list[Input],
+                     display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
+            assert(not isinstance(slot, DynamicInput))
+            self.slot = copy.copy(slot)
+            self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
+            optional = True
+            self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
+            self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
+            self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
+            super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
+            self.inputs = inputs
+            self.force_input = None
+            # force widget inputs to have no widgets, otherwise this would be awkward
+            if isinstance(self.slot, WidgetInput):
+                self.force_input = True
+                self.slot.force_input = True
+
+        def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+            if self.id in live_inputs:
+                curr_prefix = f"{curr_prefix}{self.id}."
+                add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
+                add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
+
+        def get_dynamic(self) -> list[Input]:
+            return [self.slot] + self.inputs
+
+        def get_all(self) -> list[Input]:
+            return [self] + [self.slot] + self.inputs
+
+        def as_dict(self):
+            return super().as_dict() | prune_dict({
+                "slotType": str(self.slot.get_io_type()),
+                "inputs": create_input_dict_v1(self.inputs),
+                "forceInput": self.force_input,
+            })
+
+        def validate(self):
+            self.slot.validate()
+            for input in self.inputs:
+                input.validate()
+
+def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
+    dynamic = d.setdefault("dynamic_paths", {})
+    if self is not None:
+        dynamic[self.id] = f"{curr_prefix}{self.id}"
+    for i in inputs:
+        if not isinstance(i, DynamicInput):
+            dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
+
+class V3Data(TypedDict):
+    hidden_inputs: dict[str, Any]
+    dynamic_paths: dict[str, Any]

 class HiddenHolder:
     def __init__(self, unique_id: str, prompt: Any,
@@ -984,6 +1163,7 @@ class NodeInfoV1:
     output_is_list: list[bool]=None
     output_name: list[str]=None
     output_tooltips: list[str]=None
+    output_matchtypes: list[str]=None
     name: str=None
     display_name: str=None
     description: str=None
@@ -1061,7 +1241,11 @@ class Schema:
         '''Validate the schema:
         - verify ids on inputs and outputs are unique - both internally and in relation to each other
         '''
-        input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
+        nested_inputs: list[Input] = []
+        if self.inputs is not None:
+            for input in self.inputs:
+                nested_inputs.extend(input.get_all())
+        input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
         output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
         input_set = set(input_ids)
         output_set = set(output_ids)
@@ -1077,6 +1261,13 @@
             issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
         if len(issues) > 0:
             raise ValueError("\n".join(issues))
+        # validate inputs and outputs
+        if self.inputs is not None:
+            for input in self.inputs:
+                input.validate()
+        if self.outputs is not None:
+            for output in self.outputs:
+                output.validate()

     def finalize(self):
         """Add hidden based on selected schema options, and give outputs without ids default ids."""
@@ -1102,19 +1293,10 @@
             if output.id is None:
                 output.id = f"_{i}_{output.io_type}_"

-    def get_v1_info(self, cls) -> NodeInfoV1:
+    def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
+        # NOTE: live_inputs will not be used anymore very soon and this will be done another way
         # get V1 inputs
-        input = {
-            "required": {}
-        }
-        if self.inputs:
-            for i in self.inputs:
-                if isinstance(i, DynamicInput):
-                    dynamic_inputs = i.get_dynamic()
-                    for d in dynamic_inputs:
-                        add_to_dict_v1(d, input)
-                else:
-                    add_to_dict_v1(i, input)
+        input = create_input_dict_v1(self.inputs, live_inputs)
         if self.hidden:
             for hidden in self.hidden:
                 input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@@ -1123,12 +1305,24 @@
         output_is_list = []
         output_name = []
         output_tooltips = []
+        output_matchtypes = []
+        any_matchtypes = False
         if self.outputs:
             for o in self.outputs:
                 output.append(o.io_type)
                 output_is_list.append(o.is_output_list)
                 output_name.append(o.display_name if o.display_name else o.io_type)
                 output_tooltips.append(o.tooltip if o.tooltip else None)
+                # special handling for MatchType
+                if isinstance(o, MatchType.Output):
+                    output_matchtypes.append(o.template.template_id)
+                    any_matchtypes = True
+                else:
+                    output_matchtypes.append(None)
+        # clear out lists that are all None
+        if not any_matchtypes:
+            output_matchtypes = None

         info = NodeInfoV1(
             input=input,
@@ -1137,6 +1331,7 @@
             output_is_list=output_is_list,
             output_name=output_name,
             output_tooltips=output_tooltips,
+            output_matchtypes=output_matchtypes,
             name=self.node_id,
             display_name=self.display_name,
             category=self.category,
@@ -1182,16 +1377,57 @@
         return info

-def add_to_dict_v1(i: Input, input: dict):
+def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
+    input = {
+        "required": {}
+    }
+    add_to_input_dict_v1(input, inputs, live_inputs)
+    return input
+
+def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
+    for i in inputs:
+        if isinstance(i, DynamicInput):
+            add_to_dict_v1(i, d)
+            if live_inputs is not None:
+                i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+        else:
+            add_to_dict_v1(i, d)
+
+def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
     key = "optional" if i.optional else "required"
     as_dict = i.as_dict()
     # for v1, we don't want to include the optional key
     as_dict.pop("optional", None)
-    input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
+    if dynamic_dict is None:
+        value = (i.get_io_type(), as_dict)
+    else:
+        value = (i.get_io_type(), as_dict, dynamic_dict)
+    d.setdefault(key, {})[i.id] = value

 def add_to_dict_v3(io: Input | Output, d: dict):
     d[io.id] = (io.get_io_type(), io.as_dict())

+def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
+    paths = v3_data.get("dynamic_paths", None)
+    if paths is None:
+        return values
+    values = values.copy()
+    result = {}
+
+    for key, path in paths.items():
+        parts = path.split(".")
+        current = result
+        for i, p in enumerate(parts):
+            is_last = (i == len(parts) - 1)
+            if is_last:
+                current[p] = values.pop(key, None)
+            else:
+                current = current.setdefault(p, {})
+
+    values.update(result)
+    return values
+
 class _ComfyNodeBaseInternal(_ComfyNodeInternal):
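A small worked example of build_nested_inputs(), using made-up ids in the style of the DynamicCombo test node later in this diff — the flat v1 values are regrouped into the nested dict that a v3 execute() receives:

values = {"combo": "option1", "string": "hello"}
v3_data = {"dynamic_paths": {"combo": "combo.combo", "string": "combo.string"}}
build_nested_inputs(values, v3_data)
# -> {"combo": {"combo": "option1", "string": "hello"}}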
@@ -1311,12 +1547,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
     @final
     @classmethod
-    def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
+    def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
         """Creates clone of real node class to prevent monkey-patching."""
         c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
         type_clone: type[ComfyNode] = shallow_clone_class(c_type)
         # set hidden
-        type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
+        type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
         return type_clone

     @final
@@ -1433,14 +1669,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
     @final
     @classmethod
-    def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
+    def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
         schema = cls.FINALIZE_SCHEMA()
-        info = schema.get_v1_info(cls)
+        info = schema.get_v1_info(cls, live_inputs)
         input = info.input
         if not include_hidden:
             input.pop("hidden", None)
         if return_schema:
-            return input, schema
+            v3_data: V3Data = {}
+            dynamic = input.pop("dynamic_paths", None)
+            if dynamic is not None:
+                v3_data["dynamic_paths"] = dynamic
+            return input, schema, v3_data
         return input

     @final
@@ -1513,7 +1753,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
         raise NotImplementedError

     @classmethod
-    def validate_inputs(cls, **kwargs) -> bool:
+    def validate_inputs(cls, **kwargs) -> bool | str:
         """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
         raise NotImplementedError
@@ -1628,6 +1868,7 @@ __all__ = [
     "StyleModel",
     "Gligen",
     "UpscaleModel",
+    "LatentUpscaleModel",
     "Audio",
     "Video",
     "SVG",
@@ -1651,6 +1892,10 @@ __all__ = [
     "SEGS",
     "AnyType",
     "MultiType",
+    # Dynamic Types
+    "MatchType",
+    # "DynamicCombo",
+    # "Autogrow",
     # Other classes
     "HiddenHolder",
     "Hidden",
@@ -1661,4 +1906,5 @@ __all__ = [
     "NodeOutput",
     "add_to_dict_v1",
     "add_to_dict_v3",
+    "V3Data",
 ]

View File

@@ -0,0 +1 @@
from ._io import * # noqa: F403

View File

@@ -0,0 +1 @@
from ._ui import * # noqa: F403

View File

@@ -6,7 +6,7 @@ from comfy_api.latest import (
 )
 from typing import Type, TYPE_CHECKING
 from comfy_api.internal.async_to_sync import create_sync_class
-from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
+from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401

 class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@@ -42,4 +42,8 @@ __all__ = [
     "InputImpl",
     "Types",
     "ComfyExtension",
+    "io",
+    "IO",
+    "ui",
+    "UI",
 ]

View File

@@ -1,4 +1,5 @@
 from __future__ import annotations
+from comfy_api.latest import IO

 def validate_node_input(
@@ -23,6 +24,11 @@ def validate_node_input(
     if not received_type != input_type:
         return True

+    # If the received type or input_type is a MatchType, we can return True immediately;
+    # validation for this is handled by the frontend
+    if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
+        return True
+
     # Not equal, and not strings
     if not isinstance(received_type, str) or not isinstance(input_type, str):
         return False

comfy_extras/nodes_logic.py (new file, 155 lines)
View File

@@ -0,0 +1,155 @@
from typing import TypedDict
from typing_extensions import override

from comfy_api.latest import ComfyExtension, io
from comfy_api.latest import _io


class SwitchNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        template = io.MatchType.Template("switch")
        return io.Schema(
            node_id="ComfySwitchNode",
            display_name="Switch",
            category="logic",
            is_experimental=True,
            inputs=[
                io.Boolean.Input("switch"),
                io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
                io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
            ],
            outputs=[
                io.MatchType.Output(template=template, display_name="output"),
            ],
        )

    @classmethod
    def check_lazy_status(cls, switch, on_false=..., on_true=...):
        # We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
        # This trick allows us to ignore the value of the switch and still be able to run execute().
        # One of the inputs may be missing, in which case we need to evaluate the other input
        if on_false is ...:
            return ["on_true"]
        if on_true is ...:
            return ["on_false"]
        # Normal lazy switch operation
        if switch and on_true is None:
            return ["on_true"]
        if not switch and on_false is None:
            return ["on_false"]

    @classmethod
    def validate_inputs(cls, switch, on_false=..., on_true=...):
        # This check happens before check_lazy_status(), so we can eliminate the case where
        # both inputs are missing.
        if on_false is ... and on_true is ...:
            return "At least one of on_false or on_true must be connected to Switch node"
        return True

    @classmethod
    def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
        if on_true is ...:
            return io.NodeOutput(on_false)
        if on_false is ...:
            return io.NodeOutput(on_true)
        return io.NodeOutput(on_true if switch else on_false)
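
The Ellipsis sentinel is what lets SwitchNode tell "input not connected" apart from "input connected but not evaluated yet" (which arrives as None). A standalone plain-Python sketch of the same selection logic:

# ... marks a socket that was never wired up; None marks one that is wired
# but still lazy. Same branch structure as SwitchNode.execute().
def pick(switch, on_false=..., on_true=...):
    if on_true is ...:
        return on_false   # only the false branch exists
    if on_false is ...:
        return on_true    # only the true branch exists
    return on_true if switch else on_false

assert pick(True, on_false="a") == "a"   # missing branch falls back
assert pick(False, "a", "b") == "a"
assert pick(True, "a", "b") == "b"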

class DCTestNode(io.ComfyNode):
    class DCValues(TypedDict):
        combo: str
        string: str
        integer: int
        image: io.Image.Type
        subcombo: dict[str, object]

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="DCTestNode",
            display_name="DCTest",
            category="logic",
            is_output_node=True,
            inputs=[_io.DynamicCombo.Input("combo", options=[
                _io.DynamicCombo.Option("option1", [io.String.Input("string")]),
                _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
                _io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
                _io.DynamicCombo.Option("option4", [
                    _io.DynamicCombo.Input("subcombo", options=[
                        _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
                        _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
                    ])
                ])]
            )],
            outputs=[io.AnyType.Output()],
        )

    @classmethod
    def execute(cls, combo: DCValues) -> io.NodeOutput:
        combo_val = combo["combo"]
        if combo_val == "option1":
            return io.NodeOutput(combo["string"])
        elif combo_val == "option2":
            return io.NodeOutput(combo["integer"])
        elif combo_val == "option3":
            return io.NodeOutput(combo["image"])
        elif combo_val == "option4":
            return io.NodeOutput(f"{combo['subcombo']}")
        else:
            raise ValueError(f"Invalid combo: {combo_val}")

class AutogrowNamesTestNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"])
        return io.Schema(
            node_id="AutogrowNamesTestNode",
            display_name="AutogrowNamesTest",
            category="logic",
            inputs=[
                _io.Autogrow.Input("autogrow", template=template)
            ],
            outputs=[io.String.Output()],
        )

    @classmethod
    def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
        vals = list(autogrow.values())
        combined = ",".join([str(x) for x in vals])
        return io.NodeOutput(combined)


class AutogrowPrefixTestNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10)
        return io.Schema(
            node_id="AutogrowPrefixTestNode",
            display_name="AutogrowPrefixTest",
            category="logic",
            inputs=[
                _io.Autogrow.Input("autogrow", template=template)
            ],
            outputs=[io.String.Output()],
        )

    @classmethod
    def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
        vals = list(autogrow.values())
        combined = ",".join([str(x) for x in vals])
        return io.NodeOutput(combined)
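
Both Autogrow test nodes reduce to the same join over whatever inputs materialized; only the template differs (a fixed name list versus an open-ended numbered prefix capped at max=10). The exact input-naming scheme is defined in comfy_api.latest._io, but the execute() side sees a plain mapping, roughly like this hypothetical payload:

# Hypothetical payload as execute() would receive it after autogrow expansion.
autogrow = {"a": 1.0, "b": 2.5, "c": -3.0}
print(",".join(str(x) for x in autogrow.values()))  # 1.0,2.5,-3.0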

class LogicExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            # SwitchNode,
            # DCTestNode,
            # AutogrowNamesTestNode,
            # AutogrowPrefixTestNode,
        ]


async def comfy_entrypoint() -> LogicExtension:
    return LogicExtension()
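
Note that all four nodes ship commented out of get_node_list in this commit, so registering them is opt-in. To try one locally, return it from the extension, for example:

class LogicExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [SwitchNode]  # enable only the experimental Switch node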

View File

@@ -6,6 +6,7 @@ import comfy.ops
import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
+import comfy.ldm.lumina.controlnet

class BlockWiseControlBlock(torch.nn.Module):
@@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):
        return embedding

def z_image_convert(sd):
    replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
                    ".attention.norm_k.weight": ".attention.k_norm.weight",
                    ".attention.norm_q.weight": ".attention.q_norm.weight",
                    ".attention.to_out.0.weight": ".attention.out.weight"
                    }

    out_sd = {}
    for k in sorted(sd.keys()):
        w = sd[k]
        k_out = k
        # sorted() guarantees to_k < to_q < to_v, so the q/k/v weights are
        # buffered and fused into a single qkv weight once to_v is reached
        if k_out.endswith(".attention.to_k.weight"):
            cc = [w]
            continue
        if k_out.endswith(".attention.to_q.weight"):
            cc = [w] + cc
            continue
        if k_out.endswith(".attention.to_v.weight"):
            cc = cc + [w]
            w = torch.cat(cc, dim=0)
            k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
        for r, rr in replace_keys.items():
            k_out = k_out.replace(r, rr)
        out_sd[k_out] = w
    return out_sd
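
A toy round-trip showing the q/k/v fusion z_image_convert performs (shapes are illustrative only; real checkpoints carry much larger matrices):

import torch

sd = {
    "layers.0.attention.to_q.weight": torch.randn(4, 4),
    "layers.0.attention.to_k.weight": torch.randn(4, 4),
    "layers.0.attention.to_v.weight": torch.randn(4, 4),
}
out = z_image_convert(sd)
print(out["layers.0.attention.qkv.weight"].shape)  # torch.Size([12, 4])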

class ModelPatchLoader:
    @classmethod
    def INPUT_TYPES(s):
@@ -211,6 +241,9 @@ class ModelPatchLoader:
        elif 'feature_embedder.mid_layer_norm.bias' in sd:
            sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
            model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+       elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
+           sd = z_image_convert(sd)
+           model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)

        model.load_state_dict(sd)
        model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@@ -263,6 +296,69 @@ class DiffSynthCnetPatch:
    def models(self):
        return [self.model_patch]

class ZImageControlPatch:
    def __init__(self, model_patch, vae, image, strength):
        self.model_patch = model_patch
        self.vae = vae
        self.image = image
        self.strength = strength
        self.encoded_image = self.encode_latent_cond(image)
        self.encoded_image_size = (image.shape[1], image.shape[2])
        self.temp_data = None

    def encode_latent_cond(self, image):
        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
        return latent_image

    def __call__(self, kwargs):
        x = kwargs.get("x")
        img = kwargs.get("img")
        txt = kwargs.get("txt")
        pe = kwargs.get("pe")
        vec = kwargs.get("vec")
        block_index = kwargs.get("block_index")
        spacial_compression = self.vae.spacial_compression_encode()
        if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
            # re-encode the reference image whenever the latent resolution changes
            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
            self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
            comfy.model_management.load_models_gpu(loaded_models)

        cnet_index = (block_index // 5)
        cnet_index_float = (block_index / 5)
        kwargs.pop("img")  # we do ops in place
        kwargs.pop("txt")
        cnet_blocks = self.model_patch.model.n_control_layers
        if cnet_index_float > (cnet_blocks - 1):
            self.temp_data = None
            return kwargs

        if self.temp_data is None or self.temp_data[0] > cnet_index:
            self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))

        while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
            next_layer = self.temp_data[0] + 1
            self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))

        if cnet_index_float == self.temp_data[0]:
            img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)

        if cnet_blocks == self.temp_data[0] + 1:
            self.temp_data = None

        return kwargs

    def to(self, device_or_dtype):
        if isinstance(device_or_dtype, torch.device):
            self.encoded_image = self.encoded_image.to(device_or_dtype)
            self.temp_data = None
        return self

    def models(self):
        return [self.model_patch]
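
The patch touches the image stream only on every fifth transformer block: block_index // 5 selects the control layer, the forward_control_block() loop catches the cache up when blocks were skipped, and temp_data is dropped once the last control layer has been applied. In miniature, assuming six control layers purely for illustration:

n_control_layers = 6
for block_index in range(32):
    if block_index % 5 == 0 and block_index // 5 < n_control_layers:
        print(f"block {block_index:2d} += control layer {block_index // 5}")
# blocks 0, 5, 10, 15, 20, 25 receive a residual; later blocks pass through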

class QwenImageDiffsynthControlnet:
    @classmethod
    def INPUT_TYPES(s):
@@ -289,7 +385,10 @@ class QwenImageDiffsynthControlnet:
            mask = mask.unsqueeze(2)
            mask = 1.0 - mask
-       model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
+       if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
+           model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+       else:
+           model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
        return (model_patched,)

View File

@@ -34,7 +34,7 @@ from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
-from comfy_api.latest import io
+from comfy_api.latest import io, _io

class ExecutionResult(Enum):
@@ -76,7 +76,7 @@ class IsChangedCache:
            return self.is_changed[node_id]

        # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
-       input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
+       input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
        try:
            is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
            is_changed = await resolve_map_node_over_list_results(is_changed)
@@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")

def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
    is_v3 = issubclass(class_def, _ComfyNodeInternal)
+   v3_data: io.V3Data = {}
    if is_v3:
-       valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
+       valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
    else:
        valid_inputs = class_def.INPUT_TYPES()
    input_data_all = {}
@@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
            input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
        if h[x] == "API_KEY_COMFY_ORG":
            input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
-   return input_data_all, missing_keys, hidden_inputs_v3
+   v3_data["hidden_inputs"] = hidden_inputs_v3
+   return input_data_all, missing_keys, v3_data

map_node_over_list = None #Don't hook this please
@@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results):
                raise exc
    return [x.result() if isinstance(x, asyncio.Task) else x for x in results]

-async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
+async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
    # check if node wants the lists
    input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -259,13 +261,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
            if is_class(obj):
                type_obj = obj
                obj.VALIDATE_CLASS()
-               class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
+               class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
            # otherwise, use class instance to populate/reuse some fields
            else:
                type_obj = type(obj)
                type_obj.VALIDATE_CLASS()
-               class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
+               class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
            f = make_locked_method_func(type_obj, func, class_clone)
+           # in case of dynamic inputs, restructure inputs to expected nested dict
+           if v3_data is not None:
+               inputs = _io.build_nested_inputs(inputs, v3_data)
        # V1
        else:
            f = getattr(obj, func)
@@ -320,8 +325,8 @@ def merge_result_data(results, obj):
            output.append([o[i] for o in results])
    return output

-async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
-   return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
+   return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
    has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
    if has_pending_task:
        return return_values, {}, False, has_pending_task
@@ -460,7 +465,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
        has_subgraph = False
    else:
        get_progress_state().start_progress(unique_id)
-       input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
+       input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
        if server.client_id is not None:
            server.last_node_id = display_node_id
            server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -475,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
        else:
            lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
            if lazy_status_present:
-               required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
+               required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
                required_inputs = await resolve_map_node_over_list_results(required_inputs)
                required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
                required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@@ -507,7 +512,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
        def pre_execute_cb(call_index):
            # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
            GraphBuilder.set_default_prefix(unique_id, call_index, 0)
-       output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+       output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
        if has_pending_tasks:
            pending_async_nodes[unique_id] = output_data
            unblock = execution_list.add_external_block(unique_id)
@@ -745,18 +750,17 @@ async def validate_inputs(prompt_id, prompt, item, validated):
    class_type = prompt[unique_id]['class_type']
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

-   class_inputs = obj_class.INPUT_TYPES()
-   valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
    errors = []
    valid = True

    validate_function_inputs = []
    validate_has_kwargs = False
    if issubclass(obj_class, _ComfyNodeInternal):
+       class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
        validate_function_name = "validate_inputs"
        validate_function = first_real_override(obj_class, validate_function_name)
    else:
+       class_inputs = obj_class.INPUT_TYPES()
        validate_function_name = "VALIDATE_INPUTS"
        validate_function = getattr(obj_class, validate_function_name, None)
    if validate_function is not None:
@@ -765,6 +769,8 @@ async def validate_inputs(prompt_id, prompt, item, validated):
        validate_has_kwargs = argspec.varkw is not None

    received_types = {}
+   valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
    for x in valid_inputs:
        input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
        assert extra_info is not None
@@ -935,7 +941,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
            continue

    if len(validate_function_inputs) > 0 or validate_has_kwargs:
-       input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
+       input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
        input_filtered = {}
        for x in input_data_all:
            if x in validate_function_inputs or validate_has_kwargs:
@@ -943,7 +949,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
        if 'input_types' in validate_function_inputs:
            input_filtered['input_types'] = [received_types]

-       ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
+       ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
        ret = await resolve_map_node_over_list_results(ret)
        for x in input_filtered:
            for i, r in enumerate(ret):
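
Taken together, these execution.py hunks rename the hidden_inputs plumbing to a v3_data dict that also carries the dynamic-input bookkeeping produced by INPUT_TYPES(live_inputs=...). The nested restructuring itself happens in _io.build_nested_inputs, which is not shown in this diff; a hypothetical flat-to-nested pass of the kind a DynamicCombo input needs might look like:

# Hypothetical sketch only: the real key naming and grouping are defined in
# comfy_api.latest._io, not in this commit's visible hunks.
flat = {"combo.combo": "option1", "combo.string": "hello"}
nested = {}
for key, value in flat.items():
    root, _, leaf = key.partition(".")
    nested.setdefault(root, {})[leaf] = value
print(nested)  # {'combo': {'combo': 'option1', 'string': 'hello'}}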

View File

@@ -2355,6 +2355,7 @@ async def init_builtin_extra_nodes():
        "nodes_easycache.py",
        "nodes_audio_encoder.py",
        "nodes_rope.py",
+       "nodes_logic.py",
        "nodes_nop.py",
    ]

View File

@@ -98,7 +98,7 @@ def create_cors_middleware(allowed_origin: str):
        response = await handler(request)
        response.headers['Access-Control-Allow-Origin'] = allowed_origin
-       response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
+       response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH'
        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
        response.headers['Access-Control-Allow-Credentials'] = 'true'
        return response
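
With PATCH added to the allow-list, a browser preflight for PATCH requests can now succeed. A quick way to check against a running instance (a sketch; assumes the default 127.0.0.1:8188 listen address and that the server was started with CORS enabled, e.g. via --enable-cors-header):

import urllib.request

req = urllib.request.Request("http://127.0.0.1:8188/", method="OPTIONS")
req.add_header("Origin", "http://example.com")
req.add_header("Access-Control-Request-Method", "PATCH")
with urllib.request.urlopen(req) as resp:
    print(resp.headers.get("Access-Control-Allow-Methods"))  # ... OPTIONS, PATCH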