continue building nodes / testing vae

Yousef Rafat 2025-12-07 21:41:14 +02:00
parent 041dbd6a8a
commit 4b9332cc21
3 changed files with 378 additions and 56 deletions

View File

@@ -1141,11 +1141,6 @@ def repeat(
     kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
     return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
 
-@dataclass
-class NaDiTOutput:
-    vid_sample: torch.Tensor
 
 class NaDiT(nn.Module):
     def __init__(
@@ -1246,26 +1241,32 @@ class NaDiT(nn.Module):
             "mmdit_stwin_3d_spatial",
         ]
 
-    def set_gradient_checkpointing(self, enable: bool):
-        self.gradient_checkpointing = enable
-
     def forward(
         self,
-        vid: torch.FloatTensor, # l c
-        txt: torch.FloatTensor, # l c
-        vid_shape: torch.LongTensor, # b 3
-        txt_shape: torch.LongTensor, # b 1
-        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b
-        disable_cache: bool = True, # for test
-    ):
-        # Text input.
+        x,
+        timestep,
+        context, # l c
+        txt_shape, # b 1
+        disable_cache: bool = True, # for test # TODO ?
+    ):
+        pos_cond, neg_cond = context.chunk(2, dim=0)
+        pos_cond, pos_shape = flatten(pos_cond)
+        neg_cond, neg_shape = flatten(neg_cond)
+        diff = abs(pos_shape.shape[1] - neg_shape.shape[1])
+        if pos_shape.shape[1] > neg_shape.shape[1]:
+            neg_shape = F.pad(neg_shape, (0, 0, 0, diff))
+            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
+        else:
+            pos_shape = F.pad(pos_shape, (0, 0, 0, diff))
+            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
+        vid = x
+        txt = context
+        vid, vid_shape = flatten(x)
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
         # slice vid after patching in when using sequence parallelism
         txt = self.txt_in(txt)
-        # Video input.
-        # Sequence parallel slicing is done inside patching class.
         vid, vid_shape = self.vid_in(vid, vid_shape)
 
         # Embedding input.
@@ -1284,4 +1285,5 @@ class NaDiT(nn.Module):
         )
 
         vid, vid_shape = self.vid_out(vid, vid_shape, cache)
-        return NaDiTOutput(vid_sample=vid)
+        vid = unflatten(vid, vid_shape)
+        return vid
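
The reworked forward now receives the positive and negative text embeddings concatenated along the batch axis and pads the shorter one so the two halves line up. A minimal sketch of that alignment step on dummy tensors (the token counts and channel size are made up for illustration; the real code also pads the accompanying shape tensors):

import torch
import torch.nn.functional as F

# Hypothetical positive / negative text embeddings with unequal token counts.
pos_cond = torch.randn(1, 77, 1024)  # (b, l, c)
neg_cond = torch.randn(1, 64, 1024)

# Pad the shorter sequence on the token axis, mirroring the F.pad calls above.
diff = abs(pos_cond.shape[1] - neg_cond.shape[1])
if pos_cond.shape[1] > neg_cond.shape[1]:
    neg_cond = F.pad(neg_cond, (0, 0, 0, diff))  # pads the second-to-last dim
else:
    pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

context = torch.cat([pos_cond, neg_cond], dim=0)  # what the node would pass in
pos_back, neg_back = context.chunk(2, dim=0)      # what forward() recovers
assert pos_back.shape == neg_back.shape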

View File

@@ -4,11 +4,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention
-from diffusers.models.upsampling import Upsample2D
 from einops import rearrange
 from model import safe_pad_operation
 from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
+from comfy.ldm.modules.attention import optimized_attention
 
 
 class SpatialNorm(nn.Module):
     def __init__(
@@ -28,6 +28,259 @@ class SpatialNorm(nn.Module):
         new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
         return new_f
 
+# partial implementation of diffusers's Attention for comfyui
+class Attention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        kv_heads: Optional[int] = None,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        norm_num_groups: Optional[int] = None,
+        spatial_norm_dim: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        out_dim: int = None,
+        out_context_dim: int = None,
+        context_pre_only=None,
+        pre_only=False,
+        is_causal: bool = False,
+    ):
+        super().__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.use_bias = bias
+        self.is_cross_attention = cross_attention_dim is not None
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.rescale_output_factor = rescale_output_factor
+        self.residual_connection = residual_connection
+        self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
+        self.context_pre_only = context_pre_only
+        self.pre_only = pre_only
+        self.is_causal = is_causal
+
+        # we make use of this private variable to know whether this class is loaded
+        # with an deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.sliceable_head_dim = heads
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+        else:
+            self.group_norm = None
+
+        if spatial_norm_dim is not None:
+            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+        else:
+            self.spatial_norm = None
+
+        self.norm_q = None
+        self.norm_k = None
+        self.norm_cross = None
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
+
+        self.added_proj_bias = added_proj_bias
+        if self.added_kv_proj_dim is not None:
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            if self.context_pre_only is not None:
+                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        else:
+            self.add_q_proj = None
+            self.add_k_proj = None
+            self.add_v_proj = None
+
+        if not self.pre_only:
+            self.to_out = nn.ModuleList([])
+            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+            self.to_out.append(nn.Dropout(dropout))
+        else:
+            self.to_out = None
+
+        if self.context_pre_only is not None and not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+        else:
+            self.to_add_out = None
+
+        self.norm_added_q = None
+        self.norm_added_k = None
+
+    def __call__(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        if self.spatial_norm is not None:
+            hidden_states = self.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])
+
+        if self.group_norm is not None:
+            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = self.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif self.norm_cross:
+            encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.heads
+
+        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+        if self.norm_q is not None:
+            query = self.norm_q(query)
+        if self.norm_k is not None:
+            key = self.norm_k(key)
+
+        hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if self.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / self.rescale_output_factor
+        return hidden_states
+
+
+def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str):
+    """
+    Inflate a 2D convolution weight matrix to a 3D one.
+    Parameters:
+        weight_2d: The weight matrix of 2D conv to be inflated.
+        weight_3d: The weight matrix of 3D conv to be initialized.
+        inflation_mode: the mode of inflation
+    """
+    assert inflation_mode in ["tail", "replicate"]
+    assert weight_3d.shape[:2] == weight_2d.shape[:2]
+    with torch.no_grad():
+        if inflation_mode == "replicate":
+            depth = weight_3d.size(2)
+            weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
+        else:
+            weight_3d.fill_(0.0)
+            weight_3d[:, :, -1].copy_(weight_2d)
+    return weight_3d
+
+
+def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str):
+    """
+    Inflate a 2D convolution bias tensor to a 3D one
+    Parameters:
+        bias_2d: The bias tensor of 2D conv to be inflated.
+        bias_3d: The bias tensor of 3D conv to be initialized.
+        inflation_mode: Placeholder to align `inflate_weight`.
+    """
+    assert bias_3d.shape == bias_2d.shape
+    with torch.no_grad():
+        bias_3d.copy_(bias_2d)
+    return bias_3d
+
+
+def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
+    """
+    the main function to inflated 2D parameters to 3D.
+    """
+    weight_name = prefix + "weight"
+    bias_name = prefix + "bias"
+    if weight_name in state_dict:
+        weight_2d = state_dict[weight_name]
+        if weight_2d.dim() == 4:
+            # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w)
+            weight_3d = inflate_weight_fn(
+                weight_2d=weight_2d,
+                weight_3d=layer.weight,
+                inflation_mode=layer.inflation_mode,
+            )
+            state_dict[weight_name] = weight_3d
+        else:
+            return state_dict
+    # It's a 3d state dict, should not do inflation on both bias and weight.
+    if bias_name in state_dict:
+        bias_2d = state_dict[bias_name]
+        if bias_2d.dim() == 1:
+            # Assuming the 2D biases are 1D tensors (out_channels,)
+            bias_3d = inflate_bias_fn(
+                bias_2d=bias_2d,
+                bias_3d=layer.bias,
+                inflation_mode=layer.inflation_mode,
+            )
+            state_dict[bias_name] = bias_3d
+    return state_dict
+
+
 def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
     input_dtype = x.dtype
     if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)):
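
The inflation helpers added above are what let a pretrained 2D VAE checkpoint initialize the causal 3D convolutions. A minimal sketch of the two modes on a toy weight, assuming the functions above are in scope (channel, kernel, and depth sizes are made up for illustration):

import torch

out_ch, in_ch, k, depth = 4, 3, 3, 3
weight_2d = torch.randn(out_ch, in_ch, k, k)
weight_3d = torch.empty(out_ch, in_ch, depth, k, k)

# "tail": only the last temporal slice carries the 2D kernel, so a
# single-frame input reproduces the 2D convolution exactly.
tail = inflate_weight(weight_2d, weight_3d.clone(), "tail")
assert torch.equal(tail[:, :, -1], weight_2d)
assert tail[:, :, :-1].abs().sum() == 0

# "replicate": the 2D kernel is spread evenly over the temporal axis, so a
# temporally constant input again matches the 2D convolution.
rep = inflate_weight(weight_2d, weight_3d.clone(), "replicate")
assert torch.allclose(rep.sum(dim=2), weight_2d)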
@@ -131,15 +384,14 @@ class InflatedCausalConv3d(nn.Conv3d):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        # wirdly inflation_mode is pad, which would cause an assert error
-        #if self.inflation_mode != "none":
-        #    state_dict = modify_state_dict(
-        #        self,
-        #        state_dict,
-        #        prefix,
-        #        inflate_weight_fn=inflate_weight,
-        #        inflate_bias_fn=inflate_bias,
-        #    )
+        if self.inflation_mode != "none":
+            state_dict = modify_state_dict(
+                self,
+                state_dict,
+                prefix,
+                inflate_weight_fn=inflate_weight,
+                inflate_bias_fn=inflate_bias,
+            )
         super()._load_from_state_dict(
             state_dict,
             prefix,
@@ -287,7 +539,10 @@ class Downsample3D(nn.Module):
         spatial_down: bool = False,
         temporal_down: bool = False,
         name: str = "conv",
+        kernel_size=3,
+        use_conv: bool = False,
         padding = 1,
+        bias=True,
         **kwargs,
     ):
         super().__init__()
@@ -295,7 +550,6 @@
         self.name = name
         self.channels = channels
         self.out_channels = out_channels or channels
-        conv = self.conv
         self.temporal_down = temporal_down
         self.spatial_down = spatial_down
@@ -305,9 +559,7 @@
         self.temporal_kernel = 3 if temporal_down else 1
         self.spatial_kernel = 3 if spatial_down else 1
 
-        if type(conv) in [nn.Conv2d]:
-            # Note: lora_layer is not passed into constructor in the original implementation.
-            # So we make a simplification.
+        if use_conv:
             conv = InflatedCausalConv3d(
                 self.channels,
                 self.out_channels,
@@ -320,20 +572,15 @@
                 ),
                 inflation_mode=inflation_mode,
             )
-        elif type(conv) is nn.AvgPool2d:
+        else:
             assert self.channels == self.out_channels
             conv = nn.AvgPool3d(
                 kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                 stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
             )
-        else:
-            raise NotImplementedError
-
-        if self.name == "conv":
-            self.Conv2d_0 = conv
-            self.conv = conv
-        else:
-            self.conv = conv
+
+        self.conv = conv
 
     def forward(
         self,
@@ -386,6 +633,9 @@ class ResnetBlock3D(nn.Module):
         super().__init__()
         self.up = up
         self.down = down
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.use_in_shortcut = use_in_shortcut
         self.output_scale_factor = output_scale_factor
         self.skip_time_act = skip_time_act
@@ -394,6 +644,12 @@ class ResnetBlock3D(nn.Module):
             self.time_emb_proj = nn.Linear(temb_channels, out_channels)
         else:
             self.time_emb_proj = None
+        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        if groups_out is None:
+            groups_out = groups
+        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        self.use_in_shortcut = self.in_channels != out_channels
+        self.dropout = torch.nn.Dropout(dropout)
         self.conv1 = InflatedCausalConv3d(
             self.in_channels,
             self.out_channels,
@@ -405,7 +661,7 @@ class ResnetBlock3D(nn.Module):
         self.conv2 = InflatedCausalConv3d(
             self.out_channels,
-            self.conv2.out_channels,
+            conv_2d_out_channels,
             kernel_size=3,
             stride=1,
             padding=1,
@@ -431,11 +687,11 @@ class ResnetBlock3D(nn.Module):
         if self.use_in_shortcut:
             self.conv_shortcut = InflatedCausalConv3d(
                 self.in_channels,
-                self.conv_shortcut.out_channels,
+                conv_2d_out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
-                bias=(self.conv_shortcut.bias is not None),
+                bias=True,
                 inflation_mode=inflation_mode,
             )
@@ -534,7 +790,6 @@ class DownEncoderBlock3D(nn.Module):
         if add_downsample:
             self.downsamplers = nn.ModuleList(
                 [
-                    # [Override] Replace module.
                     Downsample3D(
                         out_channels,
                         use_conv=True,
@@ -1049,8 +1304,6 @@
         self,
         in_channels: int = 3,
         out_channels: int = 3,
-        down_block_types: Tuple[str] = ("DownEncoderBlock3D",),
-        up_block_types: Tuple[str] = ("UpDecoderBlock3D",),
         layers_per_block: int = 2,
         act_fn: str = "silu",
         latent_channels: int = 16,
@@ -1059,7 +1312,7 @@
         temporal_scale_num: int = 2,
         slicing_up_num: int = 0,
         gradient_checkpoint: bool = False,
-        inflation_mode = "tail",
+        inflation_mode = "pad",
         time_receptive_field: _receptive_field_t = "full",
         use_quant_conv: bool = False,
         use_post_quant_conv: bool = False,
@@ -1068,6 +1321,8 @@
     ):
         extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None
         block_out_channels = (128, 256, 512, 512)
+        down_block_types = ("DownEncoderBlock3D",) * 4
+        up_block_types = ("UpDecoderBlock3D",) * 4
         super().__init__()
 
         # pass init params to Encoder
@@ -1257,4 +1512,4 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         #set_norm_limit(norm_max_mem)
         for m in self.modules():
             if isinstance(m, InflatedCausalConv3d):
                 m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
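
For orientation, Downsample3D builds its kernel and stride as (temporal, spatial, spatial) tuples. A minimal shape check with a plain AvgPool3d, assuming a ratio of 2 on each downsampled axis (the actual temporal_ratio / spatial_ratio values come from the constructor, which this diff does not show):

import torch
import torch.nn as nn

temporal_down, spatial_down = True, True
temporal_ratio = 2 if temporal_down else 1  # assumed ratio, not taken from the diff
spatial_ratio = 2 if spatial_down else 1

pool = nn.AvgPool3d(
    kernel_size=(temporal_ratio, spatial_ratio, spatial_ratio),
    stride=(temporal_ratio, spatial_ratio, spatial_ratio),
)

x = torch.randn(1, 512, 8, 64, 64)  # (b, c, t, h, w)
print(pool(x).shape)                # torch.Size([1, 512, 4, 32, 32])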

View File

@@ -1,6 +1,5 @@
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, io, ui
+from comfy_api.latest import ComfyExtension, io
 import torch
 import math
 from einops import rearrange
@@ -9,7 +8,51 @@ from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
 from torchvision.transforms.functional import InterpolationMode
 
+
+def expand_dims(tensor, ndim):
+    shape = tensor.shape + (1,) * (ndim - tensor.ndim)
+    return tensor.reshape(shape)
+
+
+def get_conditions(latent, latent_blur):
+    t, h, w, c = latent.shape
+    cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
+    cond[:, ..., :-1] = latent_blur[:]
+    cond[:, ..., -1:] = 1.0
+    return cond
+
+
+def timestep_transform(timesteps, latents_shapes):
+    vt = 4
+    vs = 8
+    frames = (latents_shapes[:, 0] - 1) * vt + 1
+    heights = latents_shapes[:, 1] * vs
+    widths = latents_shapes[:, 2] * vs
+
+    # Compute shift factor.
+    def get_lin_function(x1, y1, x2, y2):
+        m = (y2 - y1) / (x2 - x1)
+        b = y1 - m * x1
+        return lambda x: m * x + b
+
+    img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
+    vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
+    shift = torch.where(
+        frames > 1,
+        vid_shift_fn(heights * widths * frames),
+        img_shift_fn(heights * widths),
+    )
+
+    # Shift timesteps.
+    T = 1000.0
+    timesteps = timesteps / T
+    timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
+    timesteps = timesteps * T
+    return timesteps
+
+
+def inter(x_0, x_T, t):
+    t = expand_dims(t, x_0.ndim)
+    T = 1000.0
+    B = lambda t: t / T
+    A = lambda t: 1 - (t / T)
+    return A(t) * x_0 + B(t) * x_T
+
+
 def area_resize(image, max_area):
     height, width = image.shape[-2:]
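
timestep_transform rescales a sampling timestep by a resolution- and length-dependent shift (vt and vs appear to be the temporal and spatial VAE downsampling factors). A small worked example, assuming the helper above is in scope and that the latent shape tensor is laid out as (frames, height, width) in latent units as the function expects:

import torch

# 16 latent frames at 1280x720 pixels -> (frames, h, w) in latent units.
shape = torch.tensor([[16, 90, 160]])  # 90 * 8 = 720, 160 * 8 = 1280
t = torch.tensor([500.0])              # midpoint of the 0..1000 schedule

t_shifted = timestep_transform(t, shape)

# Pixel frames = (16 - 1) * 4 + 1 = 61, so the video shift function applies:
# shift = 1 + 4 * (720*1280*61 - 256*256*37) / (1280*720*145 - 256*256*37) ~ 2.64
# and t' = 1000 * shift * 0.5 / (1 + (shift - 1) * 0.5) ~ 725.
print(t_shifted)  # roughly tensor([725.])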
@@ -80,7 +123,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = normalize(images)
         images = rearrange(images, "t c h w -> c t h w")
         images = cut_videos(images)
-        return
+        return io.NodeOutput(images)
 
 
 class SeedVR2Conditioning(io.ComfyNode):
     @classmethod
@@ -93,16 +136,38 @@ class SeedVR2Conditioning(io.ComfyNode):
                 io.Conditioning.Input("text_negative_conditioning"),
                 io.Conditioning.Input("vae_conditioning")
             ],
-            outputs=[io.Conditioning.Output("positive"), io.Conditioning.Output("negative")],
+            outputs=[io.Conditioning.Output(display_name = "positive"),
+                     io.Conditioning.Output(display_name = "negative"),
+                     io.Latent.Output(display_name = "latent")],
         )
 
     @classmethod
     def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
-        # TODO
+        # TODO: should do the flattening logic as with the original code
         pos_cond = text_positive_conditioning[0][0]
         neg_cond = text_negative_conditioning[0][0]
-        return io.NodeOutput()
+        noises = [torch.randn_like(latent) for latent in vae_conditioning]
+        aug_noises = [torch.randn_like(latent) for latent in vae_conditioning]
+
+        cond_noise_scale = 0.0
+        t = (
+            torch.tensor([1000.0])
+            * cond_noise_scale
+        )
+        shape = torch.tensor(vae_conditioning.shape[1:])[None]
+        t = timestep_transform(t, shape)
+        cond = inter(vae_conditioning, aug_noises, t)
+        condition = get_conditions(noises, cond)
+
+        # TODO / FIXME
+        pos_cond = torch.cat([condition, pos_cond], dim = 0)
+        neg_cond = torch.cat([condition, neg_cond], dim = 0)
+
+        negative = [[pos_cond, {}]]
+        positive = [[neg_cond, {}]]
+
+        return io.NodeOutput(positive, negative, noises)
 
 
 class SeedVRExtension(ComfyExtension):
     @override
@@ -113,4 +178,4 @@ class SeedVRExtension(ComfyExtension):
         ]
 
 async def comfy_entrypoint() -> SeedVRExtension:
     return SeedVRExtension()