mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-23 13:00:54 +08:00

continue building nodes / testing vae

This commit is contained in:
parent 041dbd6a8a
commit 4b9332cc21
@@ -1141,11 +1141,6 @@ def repeat(
     kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
     return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
 
-@dataclass
-class NaDiTOutput:
-    vid_sample: torch.Tensor
-
-
 class NaDiT(nn.Module):
 
     def __init__(
@@ -1246,26 +1241,32 @@ class NaDiT(nn.Module):
             "mmdit_stwin_3d_spatial",
         ]
 
-    def set_gradient_checkpointing(self, enable: bool):
-        self.gradient_checkpointing = enable
-
     def forward(
         self,
-        vid: torch.FloatTensor,  # l c
-        txt: torch.FloatTensor,  # l c
-        vid_shape: torch.LongTensor,  # b 3
-        txt_shape: torch.LongTensor,  # b 1
-        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],  # b
-        disable_cache: bool = True,  # for test
+        x,
+        timestep,
+        context,  # l c
+        txt_shape,  # b 1
+        disable_cache: bool = True,  # for test # TODO ?
     ):
-        # Text input.
+        pos_cond, neg_cond = context.chunk(2, dim=0)
+        pos_cond, pos_shape = flatten(pos_cond)
+        neg_cond, neg_shape = flatten(neg_cond)
+        diff = abs(pos_shape.shape[1] - neg_shape.shape[1])
+        if pos_shape.shape[1] > neg_shape.shape[1]:
+            neg_shape = F.pad(neg_shape, (0, 0, 0, diff))
+            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
+        else:
+            pos_shape = F.pad(pos_shape, (0, 0, 0, diff))
+            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
+        vid = x
+        txt = context
+        vid, vid_shape = flatten(x)
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
         # slice vid after patching in when using sequence parallelism
         txt = self.txt_in(txt)
 
-        # Video input.
-        # Sequence parallel slicing is done inside patching class.
         vid, vid_shape = self.vid_in(vid, vid_shape)
 
         # Embedding input.
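Side note (not part of the commit): the new forward pass above flattens the positive and negative text conditionings separately and zero-pads the shorter one so both share the same token length. A simplified, standalone sketch of that step, with illustrative tensor shapes (the real code also pads the accompanying shape tensors):

    import torch
    import torch.nn.functional as F

    pos_cond = torch.randn(77, 4096)   # flattened positive prompt embedding (length, channels)
    neg_cond = torch.randn(32, 4096)   # flattened negative prompt embedding

    diff = abs(pos_cond.shape[0] - neg_cond.shape[0])
    if pos_cond.shape[0] > neg_cond.shape[0]:
        # F.pad counts from the last dim backwards: (0, 0) leaves channels untouched,
        # (0, diff) appends `diff` zero rows along the token/length dimension.
        neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
    else:
        pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

    assert pos_cond.shape == neg_cond.shape  # both now (77, 4096)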
@@ -1284,4 +1285,5 @@ class NaDiT(nn.Module):
         )
 
         vid, vid_shape = self.vid_out(vid, vid_shape, cache)
-        return NaDiTOutput(vid_sample=vid)
+        vid = unflatten(vid, vid_shape)
+        return vid
@@ -4,11 +4,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention
-from diffusers.models.upsampling import Upsample2D
 from einops import rearrange
 
 from model import safe_pad_operation
 from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
+from comfy.ldm.modules.attention import optimized_attention
 
 class SpatialNorm(nn.Module):
     def __init__(
@@ -28,6 +28,259 @@ class SpatialNorm(nn.Module):
         new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
         return new_f
 
+# partial implementation of diffusers's Attention for comfyui
+class Attention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        kv_heads: Optional[int] = None,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        norm_num_groups: Optional[int] = None,
+        spatial_norm_dim: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        out_dim: int = None,
+        out_context_dim: int = None,
+        context_pre_only=None,
+        pre_only=False,
+        is_causal: bool = False,
+    ):
+        super().__init__()
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.use_bias = bias
+        self.is_cross_attention = cross_attention_dim is not None
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.rescale_output_factor = rescale_output_factor
+        self.residual_connection = residual_connection
+        self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
+        self.context_pre_only = context_pre_only
+        self.pre_only = pre_only
+        self.is_causal = is_causal
+
+        # we make use of this private variable to know whether this class is loaded
+        # with an deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.sliceable_head_dim = heads
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+        else:
+            self.group_norm = None
+
+        if spatial_norm_dim is not None:
+            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+        else:
+            self.spatial_norm = None
+
+        self.norm_q = None
+        self.norm_k = None
+
+        self.norm_cross = None
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
+
+        self.added_proj_bias = added_proj_bias
+        if self.added_kv_proj_dim is not None:
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            if self.context_pre_only is not None:
+                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        else:
+            self.add_q_proj = None
+            self.add_k_proj = None
+            self.add_v_proj = None
+
+        if not self.pre_only:
+            self.to_out = nn.ModuleList([])
+            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+            self.to_out.append(nn.Dropout(dropout))
+        else:
+            self.to_out = None
+
+        if self.context_pre_only is not None and not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+        else:
+            self.to_add_out = None
+
+        self.norm_added_q = None
+        self.norm_added_k = None
+
+    def __call__(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        residual = hidden_states
+        if self.spatial_norm is not None:
+            hidden_states = self.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])
+
+        if self.group_norm is not None:
+            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = self.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif self.norm_cross:
+            encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.heads
+
+        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+        if self.norm_q is not None:
+            query = self.norm_q(query)
+        if self.norm_k is not None:
+            key = self.norm_k(key)
+
+        hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if self.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / self.rescale_output_factor
+
+        return hidden_states
+
+
+def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str):
+    """
+    Inflate a 2D convolution weight matrix to a 3D one.
+    Parameters:
+        weight_2d: The weight matrix of 2D conv to be inflated.
+        weight_3d: The weight matrix of 3D conv to be initialized.
+        inflation_mode: the mode of inflation
+    """
+    assert inflation_mode in ["tail", "replicate"]
+    assert weight_3d.shape[:2] == weight_2d.shape[:2]
+    with torch.no_grad():
+        if inflation_mode == "replicate":
+            depth = weight_3d.size(2)
+            weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
+        else:
+            weight_3d.fill_(0.0)
+            weight_3d[:, :, -1].copy_(weight_2d)
+    return weight_3d
+
+
+def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str):
+    """
+    Inflate a 2D convolution bias tensor to a 3D one
+    Parameters:
+        bias_2d: The bias tensor of 2D conv to be inflated.
+        bias_3d: The bias tensor of 3D conv to be initialized.
+        inflation_mode: Placeholder to align `inflate_weight`.
+    """
+    assert bias_3d.shape == bias_2d.shape
+    with torch.no_grad():
+        bias_3d.copy_(bias_2d)
+    return bias_3d
+
+
+def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
+    """
+    the main function to inflated 2D parameters to 3D.
+    """
+    weight_name = prefix + "weight"
+    bias_name = prefix + "bias"
+    if weight_name in state_dict:
+        weight_2d = state_dict[weight_name]
+        if weight_2d.dim() == 4:
+            # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w)
+            weight_3d = inflate_weight_fn(
+                weight_2d=weight_2d,
+                weight_3d=layer.weight,
+                inflation_mode=layer.inflation_mode,
+            )
+            state_dict[weight_name] = weight_3d
+        else:
+            return state_dict
+            # It's a 3d state dict, should not do inflation on both bias and weight.
+    if bias_name in state_dict:
+        bias_2d = state_dict[bias_name]
+        if bias_2d.dim() == 1:
+            # Assuming the 2D biases are 1D tensors (out_channels,)
+            bias_3d = inflate_bias_fn(
+                bias_2d=bias_2d,
+                bias_3d=layer.bias,
+                inflation_mode=layer.inflation_mode,
+            )
+            state_dict[bias_name] = bias_3d
+    return state_dict
+
 def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
     input_dtype = x.dtype
     if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)):
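Side note (not part of the commit): inflate_weight above turns a pretrained 2D conv kernel into a 3D one. "tail" puts the 2D kernel only in the last temporal slice (the causal-conv case), while "replicate" spreads it evenly over the temporal depth so a time-constant input reproduces the 2D response. A minimal standalone sketch of the same logic (layer sizes are arbitrary):

    import torch
    import torch.nn as nn

    conv2d = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    conv3d = nn.Conv3d(3, 8, kernel_size=(3, 3, 3), padding=1)

    # "tail": zero everywhere except the last temporal slice, which copies the 2D kernel.
    with torch.no_grad():
        conv3d.weight.zero_()
        conv3d.weight[:, :, -1].copy_(conv2d.weight)

    # "replicate": every temporal slice holds the 2D kernel divided by the depth,
    # so summing over time matches the 2D conv on a time-constant input.
    depth = conv3d.weight.size(2)
    with torch.no_grad():
        conv3d.weight.copy_(conv2d.weight.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)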
@@ -131,15 +384,14 @@ class InflatedCausalConv3d(nn.Conv3d):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        # wirdly inflation_mode is pad, which would cause an assert error
-        #if self.inflation_mode != "none":
-        # state_dict = modify_state_dict(
-        # self,
-        # state_dict,
-        # prefix,
-        # inflate_weight_fn=inflate_weight,
-        # inflate_bias_fn=inflate_bias,
-        # )
+        if self.inflation_mode != "none":
+            state_dict = modify_state_dict(
+                self,
+                state_dict,
+                prefix,
+                inflate_weight_fn=inflate_weight,
+                inflate_bias_fn=inflate_bias,
+            )
         super()._load_from_state_dict(
             state_dict,
             prefix,
@@ -287,7 +539,10 @@ class Downsample3D(nn.Module):
         spatial_down: bool = False,
         temporal_down: bool = False,
         name: str = "conv",
+        kernel_size=3,
+        use_conv: bool = False,
         padding = 1,
+        bias=True,
         **kwargs,
     ):
         super().__init__()
@@ -295,7 +550,6 @@ class Downsample3D(nn.Module):
         self.name = name
         self.channels = channels
         self.out_channels = out_channels or channels
-        conv = self.conv
         self.temporal_down = temporal_down
         self.spatial_down = spatial_down
 
@@ -305,9 +559,7 @@ class Downsample3D(nn.Module):
         self.temporal_kernel = 3 if temporal_down else 1
         self.spatial_kernel = 3 if spatial_down else 1
 
-        if type(conv) in [nn.Conv2d]:
-            # Note: lora_layer is not passed into constructor in the original implementation.
-            # So we make a simplification.
+        if use_conv:
             conv = InflatedCausalConv3d(
                 self.channels,
                 self.out_channels,
@@ -320,20 +572,15 @@ class Downsample3D(nn.Module):
                 ),
                 inflation_mode=inflation_mode,
             )
-        elif type(conv) is nn.AvgPool2d:
+        else:
             assert self.channels == self.out_channels
             conv = nn.AvgPool3d(
                 kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                 stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
             )
-        else:
-            raise NotImplementedError
 
-        if self.name == "conv":
-            self.Conv2d_0 = conv
-            self.conv = conv
-        else:
-            self.conv = conv
+        self.conv = conv
 
     def forward(
         self,
@@ -386,6 +633,9 @@ class ResnetBlock3D(nn.Module):
         super().__init__()
         self.up = up
         self.down = down
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.use_in_shortcut = use_in_shortcut
         self.output_scale_factor = output_scale_factor
         self.skip_time_act = skip_time_act
@@ -394,6 +644,12 @@ class ResnetBlock3D(nn.Module):
             self.time_emb_proj = nn.Linear(temb_channels, out_channels)
         else:
             self.time_emb_proj = None
+        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        if groups_out is None:
+            groups_out = groups
+        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        self.use_in_shortcut = self.in_channels != out_channels
+        self.dropout = torch.nn.Dropout(dropout)
         self.conv1 = InflatedCausalConv3d(
             self.in_channels,
             self.out_channels,
@@ -405,7 +661,7 @@ class ResnetBlock3D(nn.Module):
 
         self.conv2 = InflatedCausalConv3d(
             self.out_channels,
-            self.conv2.out_channels,
+            conv_2d_out_channels,
             kernel_size=3,
             stride=1,
             padding=1,
@@ -431,11 +687,11 @@ class ResnetBlock3D(nn.Module):
         if self.use_in_shortcut:
             self.conv_shortcut = InflatedCausalConv3d(
                 self.in_channels,
-                self.conv_shortcut.out_channels,
+                conv_2d_out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
-                bias=(self.conv_shortcut.bias is not None),
+                bias=True,
                 inflation_mode=inflation_mode,
             )
 
@@ -534,7 +790,6 @@ class DownEncoderBlock3D(nn.Module):
         if add_downsample:
             self.downsamplers = nn.ModuleList(
                 [
-                    # [Override] Replace module.
                     Downsample3D(
                         out_channels,
                         use_conv=True,
@@ -1049,8 +1304,6 @@ class VideoAutoencoderKL(nn.Module):
         self,
         in_channels: int = 3,
         out_channels: int = 3,
-        down_block_types: Tuple[str] = ("DownEncoderBlock3D",),
-        up_block_types: Tuple[str] = ("UpDecoderBlock3D",),
         layers_per_block: int = 2,
         act_fn: str = "silu",
         latent_channels: int = 16,
@@ -1059,7 +1312,7 @@ class VideoAutoencoderKL(nn.Module):
         temporal_scale_num: int = 2,
         slicing_up_num: int = 0,
         gradient_checkpoint: bool = False,
-        inflation_mode = "tail",
+        inflation_mode = "pad",
         time_receptive_field: _receptive_field_t = "full",
         use_quant_conv: bool = False,
         use_post_quant_conv: bool = False,
@@ -1068,6 +1321,8 @@ class VideoAutoencoderKL(nn.Module):
     ):
         extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None
         block_out_channels = (128, 256, 512, 512)
+        down_block_types = ("DownEncoderBlock3D",) * 4
+        up_block_types = ("UpDecoderBlock3D",) * 4
         super().__init__()
 
         # pass init params to Encoder
@@ -1,6 +1,5 @@
-
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, io, ui
+from comfy_api.latest import ComfyExtension, io
 import torch
 import math
 from einops import rearrange
@@ -9,7 +8,51 @@ from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
 from torchvision.transforms.functional import InterpolationMode
 
+def expand_dims(tensor, ndim):
+    shape = tensor.shape + (1,) * (ndim - tensor.ndim)
+    return tensor.reshape(shape)
+
+def get_conditions(latent, latent_blur):
+    t, h, w, c = latent.shape
+    cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
+    cond[:, ..., :-1] = latent_blur[:]
+    cond[:, ..., -1:] = 1.0
+    return cond
+
+def timestep_transform(timesteps, latents_shapes):
+    vt = 4
+    vs = 8
+    frames = (latents_shapes[:, 0] - 1) * vt + 1
+    heights = latents_shapes[:, 1] * vs
+    widths = latents_shapes[:, 2] * vs
+
+    # Compute shift factor.
+    def get_lin_function(x1, y1, x2, y2):
+        m = (y2 - y1) / (x2 - x1)
+        b = y1 - m * x1
+        return lambda x: m * x + b
+
+    img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
+    vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
+    shift = torch.where(
+        frames > 1,
+        vid_shift_fn(heights * widths * frames),
+        img_shift_fn(heights * widths),
+    )
+
+    # Shift timesteps.
+    T = 1000.0
+    timesteps = timesteps / T
+    timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
+    timesteps = timesteps * T
+    return timesteps
+
+def inter(x_0, x_T, t):
+    t = expand_dims(t, x_0.ndim)
+    T = 1000.0
+    B = lambda t: t / T
+    A = lambda t: 1 - (t / T)
+    return A(t) * x_0 + B(t) * x_T
 def area_resize(image, max_area):
 
     height, width = image.shape[-2:]
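Side note (not part of the commit): timestep_transform above applies the resolution-dependent shift t' = shift * t / (1 + (shift - 1) * t) with t normalized to [0, 1], so higher-resolution or longer inputs spend more of the schedule at the noisy end. A quick numeric check at the video anchor point of the linear fit (1280x720, 145 frames, where the fitted shift is 5.0 by construction):

    shift = 5.0          # vid_shift_fn(1280 * 720 * 145), the y2 anchor of the fit
    t = 500.0 / 1000.0   # a mid-schedule timestep, normalized by T = 1000
    t_shifted = shift * t / (1 + (shift - 1) * t)
    print(t_shifted * 1000.0)  # ~833.3: the midpoint is pushed toward the noisy end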
@@ -80,7 +123,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = normalize(images)
         images = rearrange(images, "t c h w -> c t h w")
         images = cut_videos(images)
-        return
+        return io.NodeOutput(images)
 
 class SeedVR2Conditioning(io.ComfyNode):
     @classmethod
@@ -93,16 +136,38 @@ class SeedVR2Conditioning(io.ComfyNode):
                 io.Conditioning.Input("text_negative_conditioning"),
                 io.Conditioning.Input("vae_conditioning")
             ],
-            outputs=[io.Conditioning.Output("positive"), io.Conditioning.Output("negative")],
+            outputs=[io.Conditioning.Output(display_name = "positive"),
+                     io.Conditioning.Output(display_name = "negative"),
+                     io.Latent.Output(display_name = "latent")],
         )
 
     @classmethod
     def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
-        # TODO
+        # TODO: should do the flattening logic as with the original code
        pos_cond = text_positive_conditioning[0][0]
        neg_cond = text_negative_conditioning[0][0]
 
-        return io.NodeOutput()
+        noises = [torch.randn_like(latent) for latent in vae_conditioning]
+        aug_noises = [torch.randn_like(latent) for latent in vae_conditioning]
+
+        cond_noise_scale = 0.0
+        t = (
+            torch.tensor([1000.0])
+            * cond_noise_scale
+        )
+        shape = torch.tensor(vae_conditioning.shape[1:])[None]
+        t = timestep_transform(t, shape)
+        cond = inter(vae_conditioning, aug_noises, t)
+        condition = get_conditions(noises, cond)
+
+        # TODO / FIXME
+        pos_cond = torch.cat([condition, pos_cond], dim = 0)
+        neg_cond = torch.cat([condition, neg_cond], dim = 0)
+
+        negative = [[pos_cond, {}]]
+        positive = [[neg_cond, {}]]
+
+        return io.NodeOutput(positive, negative, noises)
 
 class SeedVRExtension(ComfyExtension):
     @override
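Side note (not part of the commit): inter() is the linear interpolation x_t = (1 - t/T) * x_0 + (t/T) * x_T that the conditioning node uses to noise-augment the VAE latent before get_conditions appends the mask channel. With cond_noise_scale = 0.0 the transformed t is 0, so the augmented latent equals the clean input; raising the scale blends in Gaussian noise. A small illustrative check (tensor shapes are arbitrary):

    import torch

    x_0 = torch.randn(4, 8, 8, 16)   # "clean" conditioning latent (t h w c, illustrative)
    x_T = torch.randn_like(x_0)      # Gaussian noise
    T = 1000.0

    def inter(x_0, x_T, t):
        return (1 - t / T) * x_0 + (t / T) * x_T

    assert torch.allclose(inter(x_0, x_T, torch.tensor(0.0)), x_0)  # scale 0.0: unchanged
    mixed = inter(x_0, x_T, torch.tensor(100.0))                    # 10% noise blended in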