Merge branch 'pysssss/combo-hidden-index-output' into pysssss/basic-glsl-shader-node

commit 521ca3b5d2
@@ -18,12 +18,12 @@ class CompressedTimestep:
     def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
         """
         tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
-        patches_per_frame: Number of spatial patches per frame (height * width in latent space)
+        patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
         """
         self.batch_size, num_tokens, self.feature_dim = tensor.shape
 
         # Check if compression is valid (num_tokens must be divisible by patches_per_frame)
-        if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
+        if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
             self.patches_per_frame = patches_per_frame
             self.num_frames = num_tokens // patches_per_frame
 
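Note: the hunk above makes CompressedTimestep degrade gracefully when compression cannot apply. A minimal standalone sketch of the guard (illustrative only, not the class itself; the helper name can_compress is invented for the example):

def can_compress(num_tokens, patches_per_frame):
    # Compression engages only when patches_per_frame is given and the token
    # count is an exact multiple of it; passing None now means "always uncompressed".
    return (
        patches_per_frame is not None
        and num_tokens % patches_per_frame == 0
        and num_tokens >= patches_per_frame
    )

print(can_compress(8 * 1024, 1024))      # True: 8 frames of 1024 spatial patches each
print(can_compress(8 * 1024 + 1, 1024))  # False: token count is not frame-aligned
print(can_compress(8 * 1024, None))      # False: compression explicitly disabled
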
@@ -215,22 +215,9 @@ class BasicAVTransformerBlock(nn.Module):
         return (*scale_shift_ada_values, *gate_ada_values)
 
     def forward(
-        self,
-        x: Tuple[torch.Tensor, torch.Tensor],
-        v_context=None,
-        a_context=None,
-        attention_mask=None,
-        v_timestep=None,
-        a_timestep=None,
-        v_pe=None,
-        a_pe=None,
-        v_cross_pe=None,
-        a_cross_pe=None,
-        v_cross_scale_shift_timestep=None,
-        a_cross_scale_shift_timestep=None,
-        v_cross_gate_timestep=None,
-        a_cross_gate_timestep=None,
-        transformer_options=None,
+        self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
+        v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
+        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         run_vx = transformer_options.get("run_vx", True)
         run_ax = transformer_options.get("run_ax", True)
@@ -240,144 +227,102 @@ class BasicAVTransformerBlock(nn.Module):
         run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
         run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
 
+        # video
         if run_vx:
-            vshift_msa, vscale_msa, vgate_msa = (
-                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
-            )
+            # video self-attention
+            vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
 
             norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
-            vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
-            vx += self.attn2(
-                comfy.ldm.common_dit.rms_norm(vx),
-                context=v_context,
-                mask=attention_mask,
-                transformer_options=transformer_options,
-            )
-
-            del vshift_msa, vscale_msa, vgate_msa
+            del vshift_msa, vscale_msa
+            attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
+            del norm_vx
+            # video cross-attention
+            vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
+            vx.addcmul_(attn1_out, vgate_msa)
+            del vgate_msa, attn1_out
+            vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
 
+        # audio
         if run_ax:
-            ashift_msa, ascale_msa, agate_msa = (
-                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
-            )
+            # audio self-attention
+            ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
 
             norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
-            ax += (
-                self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
-                * agate_msa
-            )
-            ax += self.audio_attn2(
-                comfy.ldm.common_dit.rms_norm(ax),
-                context=a_context,
-                mask=attention_mask,
-                transformer_options=transformer_options,
-            )
+            del ashift_msa, ascale_msa
+            attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
+            del norm_ax
+            # audio cross-attention
+            agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
+            ax.addcmul_(attn1_out, agate_msa)
+            del agate_msa, attn1_out
+            ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
 
-            del ashift_msa, ascale_msa, agate_msa
-
-        # Audio - Video cross attention.
+        # video - audio cross attention.
         if run_a2v or run_v2a:
-            # norm3
             vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
             ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
 
-            (
-                scale_ca_audio_hidden_states_a2v,
-                shift_ca_audio_hidden_states_a2v,
-                scale_ca_audio_hidden_states_v2a,
-                shift_ca_audio_hidden_states_v2a,
-                gate_out_v2a,
-            ) = self.get_av_ca_ada_values(
-                self.scale_shift_table_a2v_ca_audio,
-                ax.shape[0],
-                a_cross_scale_shift_timestep,
-                a_cross_gate_timestep,
-            )
-
-            (
-                scale_ca_video_hidden_states_a2v,
-                shift_ca_video_hidden_states_a2v,
-                scale_ca_video_hidden_states_v2a,
-                shift_ca_video_hidden_states_v2a,
-                gate_out_a2v,
-            ) = self.get_av_ca_ada_values(
-                self.scale_shift_table_a2v_ca_video,
-                vx.shape[0],
-                v_cross_scale_shift_timestep,
-                v_cross_gate_timestep,
-            )
-
+            # audio to video cross attention
             if run_a2v:
-                vx_scaled = (
-                    vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
-                    + shift_ca_video_hidden_states_a2v
-                )
-                ax_scaled = (
-                    ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
-                    + shift_ca_audio_hidden_states_a2v
-                )
-                vx += (
-                    self.audio_to_video_attn(
-                        vx_scaled,
-                        context=ax_scaled,
-                        pe=v_cross_pe,
-                        k_pe=a_cross_pe,
-                        transformer_options=transformer_options,
-                    )
-                    * gate_out_a2v
-                )
-
-                del gate_out_a2v
-                del scale_ca_video_hidden_states_a2v,\
-                    shift_ca_video_hidden_states_a2v,\
-                    scale_ca_audio_hidden_states_a2v,\
-                    shift_ca_audio_hidden_states_a2v
+                scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
+                scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
+
+                vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
+                ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
+                del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
+
+                a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
+                del vx_scaled, ax_scaled
+
+                gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
+                vx.addcmul_(a2v_out, gate_out_a2v)
+                del gate_out_a2v, a2v_out
 
+            # video to audio cross attention
             if run_v2a:
-                ax_scaled = (
-                    ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
-                    + shift_ca_audio_hidden_states_v2a
-                )
-                vx_scaled = (
-                    vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
-                    + shift_ca_video_hidden_states_v2a
-                )
-                ax += (
-                    self.video_to_audio_attn(
-                        ax_scaled,
-                        context=vx_scaled,
-                        pe=a_cross_pe,
-                        k_pe=v_cross_pe,
-                        transformer_options=transformer_options,
-                    )
-                    * gate_out_v2a
-                )
-
-                del gate_out_v2a
-                del scale_ca_video_hidden_states_v2a,\
-                    shift_ca_video_hidden_states_v2a,\
-                    scale_ca_audio_hidden_states_v2a,\
-                    shift_ca_audio_hidden_states_v2a
+                scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
+                scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
+
+                ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
+                vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
+                del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
+
+                v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
+                del ax_scaled, vx_scaled
+
+                gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
+                ax.addcmul_(v2a_out, gate_out_v2a)
+                del gate_out_v2a, v2a_out
+
+            del vx_norm3, ax_norm3
 
+        # video feedforward
         if run_vx:
-            vshift_mlp, vscale_mlp, vgate_mlp = (
-                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
-            )
+            vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
 
             vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
-            vx += self.ff(vx_scaled) * vgate_mlp
-            del vshift_mlp, vscale_mlp, vgate_mlp
+            del vshift_mlp, vscale_mlp
+
+            ff_out = self.ff(vx_scaled)
+            del vx_scaled
+
+            vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
+            vx.addcmul_(ff_out, vgate_mlp)
+            del vgate_mlp, ff_out
 
+        # audio feedforward
         if run_ax:
-            ashift_mlp, ascale_mlp, agate_mlp = (
-                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
-            )
+            ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
 
             ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
-            ax += self.audio_ff(ax_scaled) * agate_mlp
-
-            del ashift_mlp, ascale_mlp, agate_mlp
+            del ashift_mlp, ascale_mlp
+
+            ff_out = self.audio_ff(ax_scaled)
+            del ax_scaled
+
+            agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
+            ax.addcmul_(ff_out, agate_mlp)
+            del agate_mlp, ff_out
 
         return vx, ax
 
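Note: the refactor above follows one pattern throughout: fetch each ada value only when it is needed, free it with del right after use, and replace `x += out * gate` with the in-place `x.addcmul_(out, gate)`, which avoids materializing the `out * gate` intermediate at peak memory. A minimal sketch of the equivalence with stand-in tensors (shapes are hypothetical):

import torch

x = torch.randn(2, 16, 8)
out = torch.randn(2, 16, 8)
gate = torch.randn(2, 1, 8)      # broadcasts over the token dimension

y = x + out * gate               # old style: allocates an intermediate tensor
x.addcmul_(out, gate)            # new style: fused in-place multiply-add
assert torch.allclose(x, y)
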
@@ -589,9 +534,20 @@ class LTXAVModel(LTXVModel):
         audio_length = kwargs.get("audio_length", 0)
         # Separate audio and video latents
         vx, ax = self.separate_audio_and_video_latents(x, audio_length)
+
+        has_spatial_mask = False
+        if denoise_mask is not None:
+            # check if any frame has spatial variation (inpainting)
+            for frame_idx in range(denoise_mask.shape[2]):
+                frame_mask = denoise_mask[0, 0, frame_idx]
+                if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
+                    has_spatial_mask = True
+                    break
+
         [vx, v_pixel_coords, additional_args] = super()._process_input(
             vx, keyframe_idxs, denoise_mask, **kwargs
         )
+        additional_args["has_spatial_mask"] = has_spatial_mask
 
         ax, a_latent_coords = self.a_patchifier.patchify(ax)
         ax = self.audio_patchify_proj(ax)
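Note: the loop above flags inpainting-style masks. A mask frame whose values are not all identical (min != max) varies spatially, meaning tokens within that frame can carry different timesteps. Toy illustration with an invented mask shape:

import torch

mask = torch.zeros(1, 1, 3, 4, 4)   # [batch, channel, frames, h, w] (hypothetical sizes)
mask[0, 0, 1, :2] = 1.0             # frame 1: top half masked, bottom half not
frame = mask[0, 0, 1]
print(frame.min() != frame.max())   # tensor(True) -> spatial variation detected
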
@@ -618,8 +574,9 @@ class LTXAVModel(LTXVModel):
         # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
         # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
         orig_shape = kwargs.get("orig_shape")
+        has_spatial_mask = kwargs.get("has_spatial_mask", None)
         v_patches_per_frame = None
-        if orig_shape is not None and len(orig_shape) == 5:
+        if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
             # orig_shape[3] = height, orig_shape[4] = width (in latent space)
             v_patches_per_frame = orig_shape[3] * orig_shape[4]
 
@@ -662,10 +619,11 @@ class LTXAVModel(LTXVModel):
         )
 
         # Compress cross-attention timesteps (only video side, audio is too small to benefit)
+        # v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
         cross_av_timestep_ss = [
             av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
-            CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
-            CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
+            CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
+            CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
             av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
         ]
 
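Note: tying the preceding hunks together: a spatially varying denoise mask means video tokens within one frame no longer share a timestep, so per-frame compression would be lossy. The flag therefore forces v_patches_per_frame to None and CompressedTimestep falls back to its uncompressed path. A hedged paraphrase of the decision chain (not the actual call sites; timestep_tensor stands in for the viewed tensor):

v_patches_per_frame = None if has_spatial_mask else orig_shape[3] * orig_shape[4]
# With None, the guard in CompressedTimestep.__init__ fails and the tensor is kept as-is.
ct = CompressedTimestep(timestep_tensor, v_patches_per_frame)
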
@@ -260,6 +260,7 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
             key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
             key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+            key_map[k[:-len(".weight")]] = to #DiffSynth lora format
     for k in sdk:
         hidden_size = model.model_config.unet_config.get("hidden_size", 0)
         if k.endswith(".weight") and ".linear1." in k:
@@ -1344,7 +1344,8 @@ class Schema:
     """The category of the node, as per the "Add Node" menu."""
     inputs: list[Input] = field(default_factory=list)
     outputs: list[Output] = field(default_factory=list)
-    hidden: list[Hidden] = field(default_factory=list)
+    hidden: list[Hidden | str] = field(default_factory=list)
+    """Hidden inputs. Use Hidden enum for system values (PROMPT, UNIQUE_ID, etc.) or plain strings for custom frontend-provided values."""
     description: str=""
     """Node description, shown as a tooltip when hovering over the node."""
     search_aliases: list[str] = field(default_factory=list)
@@ -1443,7 +1444,10 @@ class Schema:
         input = create_input_dict_v1(self.inputs)
         if self.hidden:
             for hidden in self.hidden:
-                input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
+                if isinstance(hidden, str):
+                    input.setdefault("hidden", {})[hidden] = (hidden,)
+                else:
+                    input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
         # create separate lists from output fields
         output = []
         output_is_list = []
@@ -1504,7 +1508,10 @@ class Schema:
         add_to_dict_v3(output, output_dict)
         if self.hidden:
             for hidden in self.hidden:
-                hidden_list.append(hidden.value)
+                if isinstance(hidden, str):
+                    hidden_list.append(hidden)
+                else:
+                    hidden_list.append(hidden.value)
 
         info = NodeInfoV3(
             input=input_dict,
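Note: with hidden widened to list[Hidden | str], a schema can mix system hidden inputs with custom frontend-provided ones. A minimal sketch, assuming the usual io.Schema constructor arguments (the node id and field values here are invented):

schema = io.Schema(
    node_id="ExampleNode",
    category="utils",
    inputs=[io.Combo.Input("choice", options=[])],
    outputs=[io.String.Output()],
    hidden=[io.Hidden.unique_id, "index"],  # enum -> resolved by the executor; str -> supplied by the frontend
)

Per the two branches above, enum entries serialize as {hidden.name: (hidden.value,)} while plain strings use the string for both key and value, e.g. "index": ("index",).
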
@@ -104,7 +104,11 @@ class CustomComboNode(io.ComfyNode):
             category="utils",
             is_experimental=True,
             inputs=[io.Combo.Input("choice", options=[])],
-            outputs=[io.String.Output()]
+            outputs=[
+                io.String.Output(display_name="STRING"),
+                io.Int.Output(display_name="INDEX"),
+            ],
+            hidden=["index"],
         )
 
     @classmethod
@@ -115,8 +119,8 @@ class CustomComboNode(io.ComfyNode):
         return True
 
     @classmethod
-    def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
-        return io.NodeOutput(choice)
+    def execute(cls, choice: io.Combo.Type, index: int = 0) -> io.NodeOutput:
+        return io.NodeOutput(choice, index)
 
 
 class DCTestNode(io.ComfyNode):
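Note: the node above now declares hidden=["index"] and accepts it in execute with a default, so the value is optional. Because "index" is a plain string rather than a Hidden enum member, the executor does not resolve it; the frontend is expected to write the selected combo index into the node's inputs in the prompt, which get_input_data (next hunk) then forwards. A sketch of what the prompt entry might look like (hypothetical JSON):

{"class_type": "CustomComboNode", "inputs": {"choice": "b", "index": 1}}
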
@@ -192,6 +192,11 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
             hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
         if io.Hidden.api_key_comfy_org.name in hidden:
             hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
+        # Handle custom hidden inputs from prompt data
+        system_hidden_names = {h.name for h in io.Hidden}
+        for hidden_name in hidden:
+            if hidden_name not in system_hidden_names and hidden_name in inputs:
+                input_data_all[hidden_name] = [inputs[hidden_name]]
     else:
         if "hidden" in valid_inputs:
             h = valid_inputs["hidden"]
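Note: a minimal standalone paraphrase of the lookup added above (the function wrapper is invented for the example; the real logic lives inline in get_input_data):

def resolve_custom_hidden(hidden, inputs, input_data_all):
    # System names (members of io.Hidden) are handled by the dedicated branches
    # earlier in the function; anything else is treated as a custom hidden input
    # and read straight from the node's prompt inputs, wrapped in a list like
    # ordinary inputs.
    system_hidden_names = {h.name for h in io.Hidden}
    for hidden_name in hidden:
        if hidden_name not in system_hidden_names and hidden_name in inputs:
            input_data_all[hidden_name] = [inputs[hidden_name]]
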