Mirror of https://github.com/comfyanonymous/ComfyUI.git
Commit 768c9cedf8, parent d629c8f910
@@ -1,8 +1,10 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union, List, Dict, Any, Callable
 import einops
-from einops import rearrange, einsum, repeat
+from einops import rearrange, repeat
+import comfy.model_management
 from torch import nn
+import torch.nn.utils.rnn as rnn_utils
 import torch.nn.functional as F
 from math import ceil, pi
 import torch
@@ -559,6 +561,8 @@ class MMModule(nn.Module):
         torch.FloatTensor,
     ]:
         vid_module = self.vid if not self.shared_weights else self.all
+        device = comfy.model_management.get_torch_device()
+        vid = vid.to(device)
         vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
         if not self.vid_only:
             txt_module = self.txt if not self.shared_weights else self.all
@@ -616,58 +620,8 @@ class NaMMAttention(nn.Module):
 
         self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)
 
-    def forward(
-        self,
-        vid: torch.FloatTensor,  # l c
-        txt: torch.FloatTensor,  # l c
-        vid_shape: torch.LongTensor,  # b 3
-        txt_shape: torch.LongTensor,  # b 1
-        cache: Cache,
-    ) -> Tuple[
-        torch.FloatTensor,
-        torch.FloatTensor,
-    ]:
-
-        vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
-        vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
-        txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
-
-        vid_q, vid_k, vid_v = vid_qkv.unbind(1)
-        txt_q, txt_k, txt_v = txt_qkv.unbind(1)
-
-        vid_q, txt_q = self.norm_q(vid_q, txt_q)
-        vid_k, txt_k = self.norm_k(vid_k, txt_k)
-
-        if self.rope:
-            if self.rope.mm:
-                vid_q, vid_k, txt_q, txt_k = self.rope(
-                    vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
-                )
-            else:
-                vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache)
-
-        vid_len = cache("vid_len", lambda: vid_shape.prod(-1))
-        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
-        all_len = cache("all_len", lambda: vid_len + txt_len)
-
-        b = len(vid_len)
-        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
-        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_v, txt_v)]
-
-        q = torch.cat([vq, tq], dim=1)
-        k = torch.cat([vk, tk], dim=1)
-        v = torch.cat([vv, tv], dim=1)
-
-        _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))
-
-        attn = optimized_attention(q, k, v, heads = self.heads, skip_reshape=True, skip_output_reshape=True)
-        attn = attn.flatten(0, 1)  # to continue working with the rest of the code
-
-        attn = rearrange(attn, "l h d -> l (h d)")
-        vid_out, txt_out = unconcat(attn)
-
-        vid_out, txt_out = self.proj_out(vid_out, txt_out)
-        return vid_out, txt_out
+    def forward(self):
+        pass
 
     def window(
         hid: torch.FloatTensor,  # (L c)
@@ -783,23 +737,78 @@ class NaSwinAttention(NaMMAttention):
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
 
         # TODO: continue testing
-        b = len(vid_len_win)
-        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
-        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_v, txt_v)]
-
-        q = torch.cat([vq, tq], dim=1)
-        k = torch.cat([vk, tk], dim=1)
-        v = torch.cat([vv, tv], dim=1)
-        out = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
-        out = out.flatten(0, 1)
+        v_lens = vid_len_win.cpu().tolist()
+        t_lens_batch = txt_len.cpu().tolist()
+        win_counts = window_count.cpu().tolist()
+
+        vq_l = torch.split(vid_q, v_lens)
+        vk_l = torch.split(vid_k, v_lens)
+        vv_l = torch.split(vid_v, v_lens)
+
+        tv_batch = torch.split(txt_v, t_lens_batch)
+        tv_l = []
+        for i, count in enumerate(win_counts):
+            tv_l.extend([tv_batch[i]] * count)
+
+        current_txt_len = txt_q.shape[0]
+        expected_batch_len = sum(t_lens_batch)
+
+        if current_txt_len != expected_batch_len:
+            t_lens_win = txt_len_win.cpu().tolist()
+
+            tq_l = torch.split(txt_q, t_lens_win)
+            tk_l = torch.split(txt_k, t_lens_win)
+        else:
+            tq_batch = torch.split(txt_q, t_lens_batch)
+            tk_batch = torch.split(txt_k, t_lens_batch)
+
+            tq_l = []
+            tk_l = []
+            for i, count in enumerate(win_counts):
+                tq_l.extend([tq_batch[i]] * count)
+                tk_l.extend([tk_batch[i]] * count)
+
+        q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)]
+        k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)]
+        v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)]
+
+        q = rnn_utils.pad_sequence(q_list, batch_first=True)
+        k = rnn_utils.pad_sequence(k_list, batch_first=True)
+        v = rnn_utils.pad_sequence(v_list, batch_first=True)
+
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        B, Heads, Max_L, _ = q.shape
+        combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)]
+
+        attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype)
+        idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L)
+        len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1)
+
+        padding_mask = idx >= len_tensor
+        attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf'))
+
+        out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True)
+
+        out = out.transpose(1, 2)
+
+        out_flat_list = []
+        for i, length in enumerate(combined_lens):
+            out_flat_list.append(out[i, :length])
+
+        out = torch.cat(out_flat_list, dim=0)
 
-        # text pooling
         vid_out, txt_out = unconcat_win(out)
 
         vid_out = rearrange(vid_out, "l h d -> l (h d)")
         txt_out = rearrange(txt_out, "l h d -> l (h d)")
         vid_out = window_reverse(vid_out)
 
+        device = comfy.model_management.get_torch_device()
+        vid_out, txt_out = vid_out.to(device), txt_out.to(device)
+        self.proj_out = self.proj_out.to(device)
         vid_out, txt_out = self.proj_out(vid_out, txt_out)
 
         return vid_out, txt_out
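Aside (not part of the commit): the new NaSwinAttention path packs variable-length video+text windows with pad_sequence and masks the padding with an additive -inf mask. A minimal, self-contained sketch of the same pattern, with illustrative names and shapes; plain torch.nn.functional.scaled_dot_product_attention stands in for ComfyUI's optimized_attention:

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def masked_varlen_attention(seqs, num_heads=4, head_dim=16):
    # seqs: list of (L_i, num_heads, head_dim) tensors with different lengths L_i
    lens = [s.shape[0] for s in seqs]
    q = pad_sequence(seqs, batch_first=True)            # (B, max_L, H, D)
    k, v = q.clone(), q.clone()                          # self-attention for the sketch
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))     # (B, H, max_L, D)

    B, H, max_L, D = q.shape
    idx = torch.arange(max_L).unsqueeze(0).expand(B, max_L)
    pad = idx >= torch.tensor(lens).unsqueeze(1)         # True where a key is padding
    mask = torch.zeros(B, 1, 1, max_L)
    mask.masked_fill_(pad.unsqueeze(1).unsqueeze(1), float("-inf"))

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # (B, H, max_L, D)
    out = out.transpose(1, 2)                            # back to (B, max_L, H, D)
    # drop the padded rows and re-pack into one flat sequence
    return torch.cat([out[i, :l] for i, l in enumerate(lens)], dim=0)

packed = masked_varlen_attention([torch.randn(l, 4, 16) for l in (5, 9, 3)])
print(packed.shape)  # torch.Size([17, 4, 16])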
@@ -837,6 +846,8 @@ class SwiGLUMLP(nn.Module):
         self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
 
     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        x = x.to(next(self.proj_in.parameters()).device)
+        self.proj_out = self.proj_out.to(x.device)
         x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
         return x
 
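Aside: the gated MLP above computes proj_out(silu(gate(x)) * proj_in(x)). A tiny standalone SwiGLU sketch with made-up dimensions, for reference:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, dim=64, hidden=256):
        super().__init__()
        self.proj_in_gate = nn.Linear(dim, hidden, bias=False)  # gate branch
        self.proj_in = nn.Linear(dim, hidden, bias=False)       # value branch
        self.proj_out = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        # out = proj_out(silu(gate(x)) * proj_in(x))
        return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))

print(SwiGLU()(torch.randn(3, 64)).shape)  # torch.Size([3, 64])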
@@ -928,6 +939,7 @@ class NaMMSRTransformerBlock(nn.Module):
         vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
         vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
         vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
+        txt = txt.to(txt_attn.device)
         vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
 
         vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
@@ -967,12 +979,11 @@ class NaPatchOut(PatchOut):
         vid: torch.FloatTensor,  # l c
         vid_shape: torch.LongTensor,
         cache: Cache = Cache(disable=True),  # for test
+        vid_shape_before_patchify = None
     ) -> Tuple[
         torch.FloatTensor,
         torch.LongTensor,
     ]:
-        cache = cache.namespace("patch")
-        vid_shape_before_patchify = cache.get("vid_shape_before_patchify")
 
         t, h, w = self.patch_size
         vid = self.proj(vid)
@@ -1074,6 +1085,16 @@ class AdaSingle(nn.Module):
         emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
         emb = expand_dims(emb, 1, hid.ndim + 1)
 
+        if hid_len is not None:
+            slice_inputs = lambda x, dim: x
+            emb = cache(
+                f"emb_repeat_{idx}_{branch_tag}",
+                lambda: slice_inputs(
+                    torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
+                    dim=0,
+                ),
+            )
+
         shiftA, scaleA, gateA = emb.unbind(-1)
         shiftB, scaleB, gateB = (
             getattr(self, f"{layer}_shift", None),
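Aside: the cached emb_repeat_* entry above expands one modulation vector per sample to one row per token of that sample's packed, variable-length sequence. An illustrative sketch of the expansion (dimensions made up):

import torch

emb = torch.randn(2, 8)        # (batch, dim): one modulation vector per sample
hid_len = [3, 5]               # tokens per sample in the packed sequence

# repeat each sample's vector hid_len[i] times, then concatenate along tokens
tokens = torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)])
print(tokens.shape)            # torch.Size([8, 8]) -> 3 + 5 rows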
@@ -1214,8 +1235,8 @@ class NaDiT(nn.Module):
         elif len(block_type) != num_layers:
             raise ValueError("The ``block_type`` list should equal to ``num_layers``.")
         super().__init__()
-        self.register_parameter("positive_conditioning", torch.empty((58, 5120)))
-        self.register_parameter("negative_conditioning", torch.empty((64, 5120)))
+        self.register_buffer("positive_conditioning", torch.empty((58, 5120)))
+        self.register_buffer("negative_conditioning", torch.empty((64, 5120)))
         self.vid_in = NaPatchIn(
             in_channels=vid_in_channels,
             patch_size=patch_size,
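Aside: one reading of the register_buffer change above is that these conditioning tensors are fixed state, not trainable weights. Buffers move with .to() and appear in state_dict but produce no gradients, and register_parameter would additionally require wrapping the tensor in nn.Parameter. A small sketch (class name is illustrative):

import torch
import torch.nn as nn

class CondHolder(nn.Module):
    def __init__(self):
        super().__init__()
        # buffer: saved/loaded and moved with the module, but never optimized
        self.register_buffer("positive_conditioning", torch.zeros(58, 5120))

m = CondHolder()
print("positive_conditioning" in m.state_dict())   # True
print(any(True for _ in m.parameters()))           # False: nothing to optimize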
@@ -1306,13 +1327,14 @@ class NaDiT(nn.Module):
         x,
         timestep,
         context,  # l c
-        disable_cache: bool = True,  # for test # TODO ?
+        disable_cache: bool = False,  # for test # TODO ? // gives an error when set to True
         **kwargs
     ):
         transformer_options = kwargs.get("transformer_options", {})
         conditions = kwargs.get("condition")
 
-        pos_cond, neg_cond = context.chunk(2, dim=0)
+        pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0)
+        pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
         pos_cond, txt_shape = flatten([pos_cond])
         neg_cond, _ = flatten([neg_cond])
         txt = torch.cat([pos_cond, neg_cond], dim = 0)
@@ -1331,6 +1353,7 @@ class NaDiT(nn.Module):
         vid = vid.to(device).to(dtype)
         txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
 
+        vid_shape_before_patchify = vid_shape
         vid, vid_shape = self.vid_in(vid, vid_shape)
 
         emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
@@ -1358,6 +1381,6 @@ class NaDiT(nn.Module):
                 branch_tag="vid",
             )
 
-        vid, vid_shape = self.vid_out(vid, vid_shape, cache)
+        vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
         vid = unflatten(vid, vid_shape)
-        return vid
+        return vid[0]
@@ -6,9 +6,31 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from comfy.ldm.seedvr.model import safe_pad_operation
-from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
 from comfy.ldm.modules.attention import optimized_attention
 
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(
+                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+            )
+
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+        sample = torch.randn(
+            self.mean.shape,
+            generator=generator,
+            device=self.parameters.device,
+            dtype=self.parameters.dtype,
+        )
+        x = self.mean + self.std * sample
+        return x
+
 class SpatialNorm(nn.Module):
     def __init__(
         self,
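Aside: a usage sketch for the DiagonalGaussianDistribution added above. A VAE encoder emits 2*C channels (mean and logvar stacked on dim=1), which are split and sampled with the reparameterization trick; the shapes below are made up for illustration:

import torch

moments = torch.randn(1, 2 * 16, 4, 32, 32)      # (B, 2C, T, H, W) encoder output
mean, logvar = torch.chunk(moments, 2, dim=1)
std = torch.exp(0.5 * logvar.clamp(-30.0, 20.0))  # clamp as in the class above
z = mean + std * torch.randn_like(std)            # equivalent to .sample()
print(z.shape)                                     # torch.Size([1, 16, 4, 32, 32])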
@@ -453,7 +475,7 @@ class Upsample3D(nn.Module):
         else:
             self.Conv2d_0 = conv
 
-        self.norm = False
+        self.norm = None
 
     def forward(
         self,
@@ -1255,6 +1277,7 @@ class Decoder3D(nn.Module):
         latent_embeds: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
 
+        sample = sample.to(next(self.parameters()).device)
         sample = self.conv_in(sample)
 
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
@@ -1397,10 +1420,10 @@ class VideoAutoencoderKL(nn.Module):
     def _decode(
         self, z: torch.Tensor
     ) -> torch.Tensor:
-        _z = z.to(self.device)
+        latent = z.to(self.device)
         if self.post_quant_conv is not None:
-            _z = self.post_quant_conv(_z)
-        output = self.decoder(_z)
+            latent = self.post_quant_conv(latent)
+        output = self.decoder(latent)
         return output.to(z.device)
 
     def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
@@ -1473,9 +1496,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return z, p
 
     def decode(self, z: torch.FloatTensor):
+        latent = z.unsqueeze(0)
+        scale = 0.9152
+        shift = 0
+        latent = latent / scale + shift
+        latent = rearrange(latent, "b ... c -> b c ...")
+        latent = latent.squeeze(2)
         if z.ndim == 4:
             z = z.unsqueeze(2)
-        x = super().decode(z).sample.squeeze(2)
+        x = super().decode(latent).squeeze(2)
         return x
 
     def preprocess(self, x: torch.Tensor):
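Aside: the latent preparation added to decode() above denormalizes channels-last latents and moves them back to channels-first before the VAE decoder sees them. A sketch with dummy shapes, reusing the scale/shift values hard-coded in the diff (0.9152 / 0):

import torch
from einops import rearrange

z = torch.randn(4, 32, 32, 16)                     # (T, H, W, C) as produced upstream
latent = z.unsqueeze(0)                             # add batch dim -> (1, T, H, W, C)
latent = latent / 0.9152 + 0                        # undo the encode-time scaling
latent = rearrange(latent, "b ... c -> b c ...")    # -> (1, C, T, H, W) for the decoder
print(latent.shape)                                 # torch.Size([1, 16, 4, 32, 32])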
@@ -15,9 +15,9 @@ def expand_dims(tensor, ndim):
 
 def get_conditions(latent, latent_blur):
     t, h, w, c = latent.shape
-    cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype)
-    #cond[:, ..., :-1] = latent_blur[:]
-    #cond[:, ..., -1:] = 1.0
+    cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
+    cond[:, ..., :-1] = latent_blur[:]
+    cond[:, ..., -1:] = 1.0
     return cond
 
 def timestep_transform(timesteps, latents_shapes):
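Aside: after the fix above, the condition tensor is the blurred latent with one extra channel of ones appended on the last axis. A tiny shape check (dummy sizes):

import torch

latent = torch.randn(4, 32, 32, 16)        # (t, h, w, c)
latent_blur = torch.randn_like(latent)

cond = torch.ones(4, 32, 32, 16 + 1)
cond[..., :-1] = latent_blur                # blurred latent in the first c channels
cond[..., -1:] = 1.0                        # mask channel of ones
print(cond.shape)                           # torch.Size([4, 32, 32, 17])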
@@ -117,6 +117,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
     @classmethod
     def execute(cls, images, vae, resolution_height, resolution_width):
         device = vae.patcher.load_device
+        offload_device = vae.patcher.offload_device
         vae = vae.first_stage_model
         scale = 0.9152; shift = 0
 
@@ -144,6 +145,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         vae = vae.to(device)
         images = images.to(device)
         latent = vae.encode(images)[0]
+        vae = vae.to(offload_device)
         latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
         latent = rearrange(latent, "b c ... -> b ... c")
 
@@ -196,8 +198,9 @@ class SeedVR2Conditioning(io.ComfyNode):
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
 
-        negative = [[neg_cond, {"condition": condition}]]
-        positive = [[pos_cond, {"condition": condition}]]
+        cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)
+        negative = [[cond, {"condition": condition}]]
+        positive = [[cond, {"condition": condition}]]
 
         return io.NodeOutput(positive, negative, {"samples": noises})
 
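Aside: the packing above pads the positive prompt embedding to the negative one's length, stacks both into a single tensor, and NaDiT.forward later splits them again with chunk(2). A sketch with illustrative sizes:

import torch
import torch.nn.functional as F

pos_cond = torch.randn(58, 5120)
neg_cond = torch.randn(64, 5120)

diff = neg_cond.shape[0] - pos_cond.shape[0]
if diff > 0:
    pos_cond = F.pad(pos_cond, (0, 0, 0, diff))          # pad token rows, not channels

cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)
print(cond.shape)                                         # torch.Size([1, 2, 64, 5120])

pos_back, neg_back = cond.squeeze(0).chunk(2, dim=0)      # mirrors NaDiT.forward
print(pos_back.shape, neg_back.shape)                     # (1, 64, 5120) each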