mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-08 20:42:32 +08:00
Compare commits
8 Commits
635bfcf47e
...
0ef5b22172
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0ef5b22172 | ||
|
|
ed6002cb60 | ||
|
|
bc72d7f8d1 | ||
|
|
aef4e13588 | ||
|
|
4e6a1b66a9 | ||
|
|
9cf299a9f9 | ||
|
|
e89b22993a | ||
|
|
55bd606e92 |
@ -8,6 +8,7 @@ class LatentFormat:
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
spacial_downscale_ratio = 8
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent * self.scale_factor
|
||||
@ -181,6 +182,7 @@ class Flux(SD3):
|
||||
|
||||
class Flux2(LatentFormat):
|
||||
latent_channels = 128
|
||||
spacial_downscale_ratio = 16
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors =[
|
||||
@ -593,6 +595,7 @@ class Wan22(Wan21):
|
||||
class HunyuanImage21(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 2
|
||||
spacial_downscale_ratio = 32
|
||||
scale_factor = 0.75289
|
||||
|
||||
latent_rgb_factors = [
|
||||
@ -726,6 +729,7 @@ class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
@ -750,6 +754,7 @@ class ACEAudio(LatentFormat):
|
||||
|
||||
class ChromaRadiance(LatentFormat):
|
||||
latent_channels = 3
|
||||
spacial_downscale_ratio = 1
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
|
||||
@ -18,12 +18,12 @@ class CompressedTimestep:
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
@ -215,22 +215,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tuple[torch.Tensor, torch.Tensor],
|
||||
v_context=None,
|
||||
a_context=None,
|
||||
attention_mask=None,
|
||||
v_timestep=None,
|
||||
a_timestep=None,
|
||||
v_pe=None,
|
||||
a_pe=None,
|
||||
v_cross_pe=None,
|
||||
a_cross_pe=None,
|
||||
v_cross_scale_shift_timestep=None,
|
||||
a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None,
|
||||
a_cross_gate_timestep=None,
|
||||
transformer_options=None,
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
@ -240,144 +227,102 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||
|
||||
# video
|
||||
if run_vx:
|
||||
vshift_msa, vscale_msa, vgate_msa = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
# video self-attention
|
||||
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||
vx += self.attn2(
|
||||
comfy.ldm.common_dit.rms_norm(vx),
|
||||
context=v_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
del vshift_msa, vscale_msa
|
||||
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
|
||||
del norm_vx
|
||||
# video cross-attention
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
vx.addcmul_(attn1_out, vgate_msa)
|
||||
del vgate_msa, attn1_out
|
||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
|
||||
# audio
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
# audio self-attention
|
||||
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
ax += (
|
||||
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
* agate_msa
|
||||
)
|
||||
ax += self.audio_attn2(
|
||||
comfy.ldm.common_dit.rms_norm(ax),
|
||||
context=a_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
del ashift_msa, ascale_msa
|
||||
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
del norm_ax
|
||||
# audio cross-attention
|
||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||
ax.addcmul_(attn1_out, agate_msa)
|
||||
del agate_msa, attn1_out
|
||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
|
||||
# Audio - Video cross attention.
|
||||
# video - audio cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
# norm3
|
||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
a_cross_scale_shift_timestep,
|
||||
a_cross_gate_timestep,
|
||||
)
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
v_cross_scale_shift_timestep,
|
||||
v_cross_gate_timestep,
|
||||
)
|
||||
|
||||
# audio to video cross attention
|
||||
if run_a2v:
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||
+ shift_ca_video_hidden_states_a2v
|
||||
)
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||
+ shift_ca_audio_hidden_states_a2v
|
||||
)
|
||||
vx += (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=v_cross_pe,
|
||||
k_pe=a_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
|
||||
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
|
||||
|
||||
del gate_out_a2v
|
||||
del scale_ca_video_hidden_states_a2v,\
|
||||
shift_ca_video_hidden_states_a2v,\
|
||||
scale_ca_audio_hidden_states_a2v,\
|
||||
shift_ca_audio_hidden_states_a2v,\
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
|
||||
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
|
||||
|
||||
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
|
||||
del vx_scaled, ax_scaled
|
||||
|
||||
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
|
||||
vx.addcmul_(a2v_out, gate_out_a2v)
|
||||
del gate_out_a2v, a2v_out
|
||||
|
||||
# video to audio cross attention
|
||||
if run_v2a:
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||
+ shift_ca_audio_hidden_states_v2a
|
||||
)
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||
+ shift_ca_video_hidden_states_v2a
|
||||
)
|
||||
ax += (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=a_cross_pe,
|
||||
k_pe=v_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
|
||||
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
|
||||
|
||||
del gate_out_v2a
|
||||
del scale_ca_video_hidden_states_v2a,\
|
||||
shift_ca_video_hidden_states_v2a,\
|
||||
scale_ca_audio_hidden_states_v2a,\
|
||||
shift_ca_audio_hidden_states_v2a
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
|
||||
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
|
||||
|
||||
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
|
||||
del ax_scaled, vx_scaled
|
||||
|
||||
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
|
||||
ax.addcmul_(v2a_out, gate_out_v2a)
|
||||
del gate_out_v2a, v2a_out
|
||||
|
||||
del vx_norm3, ax_norm3
|
||||
|
||||
# video feedforward
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
|
||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx += self.ff(vx_scaled) * vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
del vshift_mlp, vscale_mlp
|
||||
|
||||
ff_out = self.ff(vx_scaled)
|
||||
del vx_scaled
|
||||
|
||||
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
|
||||
vx.addcmul_(ff_out, vgate_mlp)
|
||||
del vgate_mlp, ff_out
|
||||
|
||||
# audio feedforward
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
|
||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||
del ashift_mlp, ascale_mlp
|
||||
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
ff_out = self.audio_ff(ax_scaled)
|
||||
del ax_scaled
|
||||
|
||||
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
|
||||
ax.addcmul_(ff_out, agate_mlp)
|
||||
del agate_mlp, ff_out
|
||||
|
||||
return vx, ax
|
||||
|
||||
@ -589,9 +534,20 @@ class LTXAVModel(LTXVModel):
|
||||
audio_length = kwargs.get("audio_length", 0)
|
||||
# Separate audio and video latents
|
||||
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||
|
||||
has_spatial_mask = False
|
||||
if denoise_mask is not None:
|
||||
# check if any frame has spatial variation (inpainting)
|
||||
for frame_idx in range(denoise_mask.shape[2]):
|
||||
frame_mask = denoise_mask[0, 0, frame_idx]
|
||||
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
|
||||
has_spatial_mask = True
|
||||
break
|
||||
|
||||
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||
)
|
||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
@ -618,8 +574,9 @@ class LTXAVModel(LTXVModel):
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if orig_shape is not None and len(orig_shape) == 5:
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
@ -662,10 +619,11 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||
]
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
self._padding = 2 * self.padding[0]
|
||||
self.padding = (0, self.padding[1], self.padding[2])
|
||||
|
||||
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||
if cache_list is not None:
|
||||
cache_x = cache_list[cache_idx]
|
||||
cache_list[cache_idx] = None
|
||||
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
if cache_x is None and x.shape[2] == 1:
|
||||
#Fast path - the op will pad for use by truncating the weight
|
||||
#and save math on a pile of zeros.
|
||||
return super().forward(x, autopad="causal_zero")
|
||||
|
||||
if self._padding > 0:
|
||||
padding_needed = self._padding
|
||||
if cache_x is not None:
|
||||
cache_x = cache_x.to(x.device)
|
||||
padding_needed = max(0, padding_needed - cache_x.shape[2])
|
||||
padding_shape = list(x.shape)
|
||||
padding_shape[2] = padding_needed
|
||||
padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
|
||||
x = torch_cat_if_needed([padding, cache_x, x], dim=2)
|
||||
del cache_x
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
@ -260,6 +260,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||
key_map[k[:-len(".weight")]] = to #DiffSynth lora format
|
||||
for k in sdk:
|
||||
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
|
||||
if k.endswith(".weight") and ".linear1." in k:
|
||||
|
||||
10
comfy/ops.py
10
comfy/ops.py
@ -203,7 +203,9 @@ class disable_weight_init:
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
||||
def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
|
||||
if autopad == "causal_zero":
|
||||
weight = weight[:, :, -input.shape[2]:, :, :]
|
||||
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
||||
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
||||
if bias is not None:
|
||||
@ -212,15 +214,15 @@ class disable_weight_init:
|
||||
else:
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
def forward_comfy_cast_weights(self, input, autopad=None):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
x = self._conv_forward(input, weight, bias, autopad=autopad)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
|
||||
return noises
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
|
||||
if latent_image.is_nested:
|
||||
return latent_image
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if torch.count_nonzero(latent_image) == 0:
|
||||
if latent_format.latent_channels != latent_image.shape[1]:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if downscale_ratio_spacial is not None:
|
||||
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
|
||||
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
|
||||
latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled")
|
||||
|
||||
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
|
||||
latent_image = latent_image.unsqueeze(2)
|
||||
return latent_image
|
||||
|
||||
@ -1383,6 +1383,8 @@ class Schema:
|
||||
"""Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph."""
|
||||
enable_expand: bool=False
|
||||
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
|
||||
accept_all_inputs: bool=False
|
||||
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
|
||||
|
||||
def validate(self):
|
||||
'''Validate the schema:
|
||||
@ -1853,6 +1855,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls.GET_SCHEMA()
|
||||
return cls._NOT_IDEMPOTENT
|
||||
|
||||
_ACCEPT_ALL_INPUTS = None
|
||||
@final
|
||||
@classproperty
|
||||
def ACCEPT_ALL_INPUTS(cls): # noqa
|
||||
if cls._ACCEPT_ALL_INPUTS is None:
|
||||
cls.GET_SCHEMA()
|
||||
return cls._ACCEPT_ALL_INPUTS
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||
@ -1891,6 +1901,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls._INPUT_IS_LIST = schema.is_input_list
|
||||
if cls._NOT_IDEMPOTENT is None:
|
||||
cls._NOT_IDEMPOTENT = schema.not_idempotent
|
||||
if cls._ACCEPT_ALL_INPUTS is None:
|
||||
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
|
||||
|
||||
if cls._RETURN_TYPES is None:
|
||||
output = []
|
||||
|
||||
66
comfy_api_nodes/apis/hunyuan3d.py
Normal file
66
comfy_api_nodes/apis/hunyuan3d.py
Normal file
@ -0,0 +1,66 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class InputGenerateType(TypedDict):
|
||||
generate_type: str
|
||||
polygon_type: str
|
||||
pbr: bool
|
||||
|
||||
|
||||
class Hunyuan3DViewImage(BaseModel):
|
||||
ViewType: str = Field(..., description="Valid values: back, left, right.")
|
||||
ViewImageUrl: str = Field(...)
|
||||
|
||||
|
||||
class To3DProTaskRequest(BaseModel):
|
||||
Model: str = Field(...)
|
||||
Prompt: str | None = Field(None)
|
||||
ImageUrl: str | None = Field(None)
|
||||
MultiViewImages: list[Hunyuan3DViewImage] | None = Field(None)
|
||||
EnablePBR: bool | None = Field(...)
|
||||
FaceCount: int | None = Field(...)
|
||||
GenerateType: str | None = Field(...)
|
||||
PolygonType: str | None = Field(...)
|
||||
|
||||
|
||||
class RequestError(BaseModel):
|
||||
Code: str = Field("")
|
||||
Message: str = Field("")
|
||||
|
||||
|
||||
class To3DProTaskCreateResponse(BaseModel):
|
||||
JobId: str | None = Field(None)
|
||||
Error: RequestError | None = Field(None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def unwrap_data(cls, values: dict) -> dict:
|
||||
if "Response" in values and isinstance(values["Response"], dict):
|
||||
return values["Response"]
|
||||
return values
|
||||
|
||||
|
||||
class ResultFile3D(BaseModel):
|
||||
Type: str = Field(...)
|
||||
Url: str = Field(...)
|
||||
PreviewImageUrl: str = Field("")
|
||||
|
||||
|
||||
class To3DProTaskResultResponse(BaseModel):
|
||||
ErrorCode: str = Field("")
|
||||
ErrorMessage: str = Field("")
|
||||
ResultFile3Ds: list[ResultFile3D] = Field([])
|
||||
Status: str = Field(...)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def unwrap_data(cls, values: dict) -> dict:
|
||||
if "Response" in values and isinstance(values["Response"], dict):
|
||||
return values["Response"]
|
||||
return values
|
||||
|
||||
|
||||
class To3DProTaskQueryRequest(BaseModel):
|
||||
JobId: str = Field(...)
|
||||
297
comfy_api_nodes/nodes_hunyuan3d.py
Normal file
297
comfy_api_nodes/nodes_hunyuan3d.py
Normal file
@ -0,0 +1,297 @@
|
||||
import os
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.hunyuan3d import (
|
||||
Hunyuan3DViewImage,
|
||||
InputGenerateType,
|
||||
ResultFile3D,
|
||||
To3DProTaskCreateResponse,
|
||||
To3DProTaskQueryRequest,
|
||||
To3DProTaskRequest,
|
||||
To3DProTaskResultResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D:
|
||||
for i in response_objs:
|
||||
if i.Type.lower() == "glb":
|
||||
return i
|
||||
raise ValueError("No GLB file found in response. Please report this to the developers.")
|
||||
|
||||
|
||||
class TencentTextToModelNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentTextToModelNode",
|
||||
display_name="Hunyuan3D: Text to Model (Pro)",
|
||||
category="api node/3d/Tencent",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["3.0", "3.1"],
|
||||
tooltip="The LowPoly option is unavailable for the `3.1` model.",
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."),
|
||||
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
|
||||
IO.DynamicCombo.Input(
|
||||
"generate_type",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]),
|
||||
IO.DynamicCombo.Option(
|
||||
"LowPoly",
|
||||
[
|
||||
IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]),
|
||||
IO.Boolean.Input("pbr", default=False),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("Geometry", []),
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["generate_type", "generate_type.pbr", "face_count"]),
|
||||
expr="""
|
||||
(
|
||||
$base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15;
|
||||
$pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0;
|
||||
$face := widgets.face_count != 500000 ? 10 : 0;
|
||||
{"type":"usd","usd": ($base + $pbr + $face) * 0.02}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
face_count: int,
|
||||
generate_type: InputGenerateType,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
_ = seed
|
||||
validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
|
||||
if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly":
|
||||
raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.")
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=To3DProTaskRequest(
|
||||
Model=model,
|
||||
Prompt=prompt,
|
||||
FaceCount=face_count,
|
||||
GenerateType=generate_type["generate_type"],
|
||||
EnablePBR=generate_type.get("pbr", None),
|
||||
PolygonType=generate_type.get("polygon_type", None),
|
||||
),
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
model_file = f"hunyuan_model_{response.JobId}.glb"
|
||||
await download_url_to_bytesio(
|
||||
get_glb_obj_from_response(result.ResultFile3Ds).Url,
|
||||
os.path.join(get_output_directory(), model_file),
|
||||
)
|
||||
return IO.NodeOutput(model_file)
|
||||
|
||||
|
||||
class TencentImageToModelNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentImageToModelNode",
|
||||
display_name="Hunyuan3D: Image(s) to Model (Pro)",
|
||||
category="api node/3d/Tencent",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["3.0", "3.1"],
|
||||
tooltip="The LowPoly option is unavailable for the `3.1` model.",
|
||||
),
|
||||
IO.Image.Input("image"),
|
||||
IO.Image.Input("image_left", optional=True),
|
||||
IO.Image.Input("image_right", optional=True),
|
||||
IO.Image.Input("image_back", optional=True),
|
||||
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
|
||||
IO.DynamicCombo.Input(
|
||||
"generate_type",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]),
|
||||
IO.DynamicCombo.Option(
|
||||
"LowPoly",
|
||||
[
|
||||
IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]),
|
||||
IO.Boolean.Input("pbr", default=False),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("Geometry", []),
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["generate_type", "generate_type.pbr", "face_count"],
|
||||
inputs=["image_left", "image_right", "image_back"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15;
|
||||
$multiview := (
|
||||
inputs.image_left.connected or inputs.image_right.connected or inputs.image_back.connected
|
||||
) ? 10 : 0;
|
||||
$pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0;
|
||||
$face := widgets.face_count != 500000 ? 10 : 0;
|
||||
{"type":"usd","usd": ($base + $multiview + $pbr + $face) * 0.02}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
face_count: int,
|
||||
generate_type: InputGenerateType,
|
||||
seed: int,
|
||||
image_left: Input.Image | None = None,
|
||||
image_right: Input.Image | None = None,
|
||||
image_back: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
_ = seed
|
||||
if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly":
|
||||
raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.")
|
||||
validate_image_dimensions(image, min_width=128, min_height=128)
|
||||
multiview_images = []
|
||||
for k, v in {
|
||||
"left": image_left,
|
||||
"right": image_right,
|
||||
"back": image_back,
|
||||
}.items():
|
||||
if v is None:
|
||||
continue
|
||||
validate_image_dimensions(v, min_width=128, min_height=128)
|
||||
multiview_images.append(
|
||||
Hunyuan3DViewImage(
|
||||
ViewType=k,
|
||||
ViewImageUrl=await upload_image_to_comfyapi(
|
||||
cls,
|
||||
downscale_image_tensor_by_max_side(v, max_side=4900),
|
||||
mime_type="image/webp",
|
||||
total_pixels=24_010_000,
|
||||
),
|
||||
)
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=To3DProTaskRequest(
|
||||
Model=model,
|
||||
FaceCount=face_count,
|
||||
GenerateType=generate_type["generate_type"],
|
||||
ImageUrl=await upload_image_to_comfyapi(
|
||||
cls,
|
||||
downscale_image_tensor_by_max_side(image, max_side=4900),
|
||||
mime_type="image/webp",
|
||||
total_pixels=24_010_000,
|
||||
),
|
||||
MultiViewImages=multiview_images if multiview_images else None,
|
||||
EnablePBR=generate_type.get("pbr", None),
|
||||
PolygonType=generate_type.get("polygon_type", None),
|
||||
),
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
model_file = f"hunyuan_model_{response.JobId}.glb"
|
||||
await download_url_to_bytesio(
|
||||
get_glb_obj_from_response(result.ResultFile3Ds).Url,
|
||||
os.path.join(get_output_directory(), model_file),
|
||||
)
|
||||
return IO.NodeOutput(model_file)
|
||||
|
||||
|
||||
class TencentHunyuan3DExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TencentTextToModelNode,
|
||||
TencentImageToModelNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> TencentHunyuan3DExtension:
|
||||
return TencentHunyuan3DExtension()
|
||||
@ -249,7 +249,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
max_poll_attempts=160,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
@ -149,7 +149,6 @@ class OpenAIVideoSora2(IO.ComfyNode):
|
||||
response_model=Sora2GenerationResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
poll_interval=8.0,
|
||||
max_poll_attempts=160,
|
||||
estimated_duration=int(45 * (duration / 4) * model_time_multiplier),
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
|
||||
@ -203,7 +203,6 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: x.credits * 0.08,
|
||||
poll_interval=8.0,
|
||||
max_poll_attempts=160,
|
||||
estimated_duration=60,
|
||||
)
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from .conversions import (
|
||||
bytesio_to_image_tensor,
|
||||
convert_mask_to_image,
|
||||
downscale_image_tensor,
|
||||
downscale_image_tensor_by_max_side,
|
||||
image_tensor_pair_to_batch,
|
||||
pil_to_bytesio,
|
||||
resize_mask_to_image,
|
||||
@ -33,6 +34,7 @@ from .download_helpers import (
|
||||
from .upload_helpers import (
|
||||
upload_audio_to_comfyapi,
|
||||
upload_file_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
)
|
||||
@ -61,6 +63,7 @@ __all__ = [
|
||||
# Upload helpers
|
||||
"upload_audio_to_comfyapi",
|
||||
"upload_file_to_comfyapi",
|
||||
"upload_image_to_comfyapi",
|
||||
"upload_images_to_comfyapi",
|
||||
"upload_video_to_comfyapi",
|
||||
# Download helpers
|
||||
@ -75,6 +78,7 @@ __all__ = [
|
||||
"bytesio_to_image_tensor",
|
||||
"convert_mask_to_image",
|
||||
"downscale_image_tensor",
|
||||
"downscale_image_tensor_by_max_side",
|
||||
"image_tensor_pair_to_batch",
|
||||
"pil_to_bytesio",
|
||||
"resize_mask_to_image",
|
||||
|
||||
@ -141,7 +141,7 @@ async def poll_op(
|
||||
queued_statuses: list[str | int] | None = None,
|
||||
data: BaseModel | None = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
max_poll_attempts: int = 160,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
@ -238,7 +238,7 @@ async def poll_op_raw(
|
||||
queued_statuses: list[str | int] | None = None,
|
||||
data: dict[str, Any] | BaseModel | None = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
max_poll_attempts: int = 160,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
|
||||
@ -144,6 +144,21 @@ def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024)
|
||||
return s
|
||||
|
||||
|
||||
def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
|
||||
"""Downscale input image tensor so the largest dimension is at most max_side pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
height, width = samples.shape[2], samples.shape[3]
|
||||
max_dim = max(width, height)
|
||||
if max_dim <= max_side:
|
||||
return image
|
||||
scale_by = max_side / max_dim
|
||||
new_width = round(width * scale_by)
|
||||
new_height = round(height * scale_by)
|
||||
s = common_upscale(samples, new_width, new_height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
|
||||
|
||||
|
||||
def tensor_to_data_uri(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
|
||||
@ -88,6 +88,28 @@ async def upload_images_to_comfyapi(
|
||||
return download_urls
|
||||
|
||||
|
||||
async def upload_image_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
mime_type: str | None = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
total_pixels: int = 2048 * 2048,
|
||||
) -> str:
|
||||
"""Uploads a single image to ComfyUI API and returns its download URL."""
|
||||
return (
|
||||
await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type=mime_type,
|
||||
wait_label=wait_label,
|
||||
show_batch_index=False,
|
||||
total_pixels=total_pixels,
|
||||
)
|
||||
)[0]
|
||||
|
||||
|
||||
async def upload_audio_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
audio: Input.Audio,
|
||||
|
||||
@ -741,7 +741,7 @@ class SamplerCustom(io.ComfyNode):
|
||||
latent = latent_image
|
||||
latent_image = latent["samples"]
|
||||
latent = latent.copy()
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
|
||||
latent["samples"] = latent_image
|
||||
|
||||
if not add_noise:
|
||||
@ -760,6 +760,7 @@ class SamplerCustom(io.ComfyNode):
|
||||
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
|
||||
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = samples
|
||||
if "x0" in x0_output:
|
||||
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
|
||||
@ -939,7 +940,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
|
||||
latent = latent_image
|
||||
latent_image = latent["samples"]
|
||||
latent = latent.copy()
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None))
|
||||
latent["samples"] = latent_image
|
||||
|
||||
noise_mask = None
|
||||
@ -954,6 +955,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = samples
|
||||
if "x0" in x0_output:
|
||||
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
|
||||
|
||||
@ -104,19 +104,23 @@ class CustomComboNode(io.ComfyNode):
|
||||
category="utils",
|
||||
is_experimental=True,
|
||||
inputs=[io.Combo.Input("choice", options=[])],
|
||||
outputs=[io.String.Output()]
|
||||
outputs=[
|
||||
io.String.Output(display_name="STRING"),
|
||||
io.Int.Output(display_name="INDEX"),
|
||||
],
|
||||
accept_all_inputs=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, choice: io.Combo.Type) -> bool:
|
||||
def validate_inputs(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> bool:
|
||||
# NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
|
||||
# I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
|
||||
# I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def execute(cls, choice: io.Combo.Type) -> io.NodeOutput:
|
||||
return io.NodeOutput(choice)
|
||||
def execute(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> io.NodeOutput:
|
||||
return io.NodeOutput(choice, index)
|
||||
|
||||
|
||||
class DCTestNode(io.ComfyNode):
|
||||
|
||||
@ -55,7 +55,7 @@ class EmptySD3LatentImage(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples":latent})
|
||||
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
@ -175,7 +175,7 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
continue
|
||||
obj = cached.outputs[output_index]
|
||||
input_data_all[x] = obj
|
||||
elif input_category is not None:
|
||||
elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS):
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if is_v3:
|
||||
|
||||
5
nodes.py
5
nodes.py
@ -1241,7 +1241,7 @@ class EmptyLatentImage:
|
||||
|
||||
def generate(self, width, height, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
|
||||
return ({"samples":latent}, )
|
||||
return ({"samples": latent, "downscale_ratio_spacial": 8}, )
|
||||
|
||||
|
||||
class LatentFromBatch:
|
||||
@ -1549,7 +1549,7 @@ class SetLatentNoiseMask:
|
||||
|
||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||
latent_image = latent["samples"]
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
|
||||
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
|
||||
|
||||
if disable_noise:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
@ -1567,6 +1567,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
||||
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = samples
|
||||
return (out, )
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user