mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-23 13:00:54 +08:00)

commit d629c8f910 (parent 413ee3f687): testing
@@ -13,7 +13,7 @@ if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
 
-def get_timestep_embedding(timesteps, embedding_dim):
+def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
     From Fairseq.
@@ -24,11 +24,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
     assert len(timesteps.shape) == 1
 
     half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
+    emb = math.log(10000) / (half_dim - downscale_freq_shift)
     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
     emb = emb.to(device=timesteps.device)
     emb = timesteps.float()[:, None] * emb[None, :]
     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
     if embedding_dim % 2 == 1: # zero pad
         emb = torch.nn.functional.pad(emb, (0,1,0,0))
     return emb
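For reference, the two new keyword arguments follow the same convention as the analogous helper in diffusers: downscale_freq_shift changes the denominator of the log-frequency spacing (1 reproduces the old DDPM behaviour, 0 spaces by half_dim), and flip_sin_to_cos swaps the two halves so the cosine block comes first. A minimal sketch of the effect, with toy shapes assumed rather than taken from the commit:

import torch

timesteps = torch.tensor([0, 10, 100, 1000])

emb = get_timestep_embedding(timesteps, 8)  # (4, 8), layout [sin | cos]
emb_flip = get_timestep_embedding(timesteps, 8, flip_sin_to_cos=True)

# flip_sin_to_cos only reorders the halves; the values are unchanged.
assert torch.equal(emb_flip[:, :4], emb[:, 4:])
assert torch.equal(emb_flip[:, 4:], emb[:, :4])

# downscale_freq_shift=0 divides log(10000) by half_dim instead of
# half_dim - 1, slightly compressing the frequency range.
emb_shift = get_timestep_embedding(timesteps, 8, downscale_freq_shift=0)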
@@ -1,10 +1,10 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union, List, Dict, Any, Callable
 import einops
-from einops import rearrange, einsum
+from einops import rearrange, einsum, repeat
 from torch import nn
 import torch.nn.functional as F
-from math import ceil, sqrt, pi
+from math import ceil, pi
 import torch
 from itertools import chain
 from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
@@ -12,6 +12,7 @@ from comfy.ldm.modules.attention import optimized_attention
 from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
+import math
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -354,8 +355,8 @@ class RotaryEmbedding(nn.Module):
 
         freqs = self.freqs
 
-        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
-        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
+        freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
+        freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2)
 
         if should_cache and offset == 0:
             self.cached_freqs[:seq_len] = freqs.detach()
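The einsum change is a real fix, not a style edit: einops.einsum takes the operand tensors first and the pattern string last, so the old call, which passed the pattern first (as torch.einsum expects), could not have worked with the einops import. A small sketch of the two calling conventions, with assumed toy shapes:

import torch
import einops

t = torch.arange(6, dtype=torch.float32)  # positions, shape (6,)
freqs = torch.rand(4)                     # frequencies, shape (4,)

out_torch = torch.einsum('..., f -> ... f', t, freqs)    # equation first: (6, 4)
out_einops = einops.einsum(t, freqs, '..., f -> ... f')  # pattern last: (6, 4)

assert torch.allclose(out_torch, out_einops)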
@@ -460,6 +461,7 @@ def apply_rotary_emb(
    t_middle = t[..., start_index:end_index]
    t_right = t[..., end_index:]
 
+    freqs = freqs.to(t_middle.device)
    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
 
    out = torch.cat((t_left, t_transformed, t_right), dim=-1)
@@ -560,6 +562,7 @@ class MMModule(nn.Module):
         vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
         if not self.vid_only:
             txt_module = self.txt if not self.shared_weights else self.all
+            txt = txt.to(device=vid.device, dtype=vid.dtype)
             txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
         return vid, txt
 
@@ -747,6 +750,7 @@ class NaSwinAttention(NaMMAttention):
         txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
 
         vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
+        txt_len = txt_len.to(window_count.device)
         txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
         all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
         concat_win, unconcat_win = cache_win(
@@ -1122,8 +1126,12 @@ class TimeEmbedding(nn.Module):
         emb = emb.to(dtype)
         emb = self.proj_in(emb)
         emb = self.act(emb)
+        device = next(self.proj_hid.parameters()).device
+        emb = emb.to(device)
         emb = self.proj_hid(emb)
         emb = self.act(emb)
+        device = next(self.proj_out.parameters()).device
+        emb = emb.to(device)
         emb = self.proj_out(emb)
         return emb
 
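The inserted lookups follow a common pattern for partially offloaded models: before calling each projection, fetch the device of that layer's own parameters and move the activation there, which is a no-op when everything lives on one device. A generic sketch of the pattern (module names here are illustrative, not from the commit):

import torch
from torch import nn

class DeviceSafeMLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj_in = nn.Linear(dim, dim)
        self.proj_out = nn.Linear(dim, dim)

    def forward(self, x):
        x = x.to(next(self.proj_in.parameters()).device)   # follow the layer
        x = self.proj_in(x)
        x = x.to(next(self.proj_out.parameters()).device)  # again before the next one
        return self.proj_out(x)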
@@ -1206,6 +1214,8 @@ class NaDiT(nn.Module):
         elif len(block_type) != num_layers:
             raise ValueError("The ``block_type`` list should equal to ``num_layers``.")
         super().__init__()
+        self.register_parameter("positive_conditioning", nn.Parameter(torch.empty((58, 5120))))
+        self.register_parameter("negative_conditioning", nn.Parameter(torch.empty((64, 5120))))
         self.vid_in = NaPatchIn(
             in_channels=vid_in_channels,
             patch_size=patch_size,
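One caveat on the two added lines: nn.Module.register_parameter only accepts an nn.Parameter (or None), so the raw torch.empty tensors are wrapped in nn.Parameter above; passing them unwrapped raises a TypeError. A minimal check:

import torch
from torch import nn

m = nn.Module()
m.register_parameter("positive_conditioning", nn.Parameter(torch.empty(58, 5120)))  # ok
# m.register_parameter("negative_conditioning", torch.empty(64, 5120))  # raises TypeError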
@@ -1303,11 +1313,9 @@ class NaDiT(nn.Module):
         conditions = kwargs.get("condition")
 
         pos_cond, neg_cond = context.chunk(2, dim=0)
-        # txt_shape should be the same for both
-        pos_cond, txt_shape = flatten(pos_cond)
-        neg_cond, _ = flatten(neg_cond)
+        pos_cond, txt_shape = flatten([pos_cond])
+        neg_cond, _ = flatten([neg_cond])
         txt = torch.cat([pos_cond, neg_cond], dim = 0)
-        txt_shape[0] *= 2
 
         vid = x
         vid, vid_shape = flatten(x)
@@ -1321,7 +1329,7 @@ class NaDiT(nn.Module):
         dtype = next(self.parameters()).dtype
         txt = txt.to(device).to(dtype)
         vid = vid.to(device).to(dtype)
-        txt = self.txt_in(txt)
+        txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device))
 
         vid, vid_shape = self.vid_in(vid, vid_shape)
 
@@ -801,7 +801,8 @@ class SeedVR2(BaseModel):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         condition = kwargs.get("condition", None)
-        out["condition"] = comfy.conds.CONDRegular(condition)
+        if condition is not None:
+            out["condition"] = comfy.conds.CONDRegular(condition)
         return out
 
 class PixArt(BaseModel):
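The None guard matters because CONDRegular stores whatever it is handed and only touches it later, during sampling, when it is repeated to batch size and moved to the sampling device; wrapping None therefore defers the crash to a much less obvious place. A simplified sketch of that failure mode (CONDRegular reduced to the relevant behaviour, not its full implementation):

class CONDRegularSketch:
    def __init__(self, cond):
        self.cond = cond  # accepted silently, even if None

    def process_cond(self, batch_size, device):
        # The real class repeats to batch size and moves to device;
        # both are tensor ops that fail on None.
        return self.cond.to(device)

c = CONDRegularSketch(None)
# c.process_cond(1, "cpu")  # AttributeError: 'NoneType' object has no attribute 'to'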
@@ -1168,7 +1168,7 @@ class SeedVR2(supported_models_base.BASE):
         out = model_base.SeedVR2(self, device=device)
         return out
     def clip_target(self, state_dict={}):
-        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.SD3ClipModel)
+        return None
 
 class ACEStep(supported_models_base.BASE):
     unet_config = {
@@ -15,9 +15,9 @@ def expand_dims(tensor, ndim):
 
 def get_conditions(latent, latent_blur):
     t, h, w, c = latent.shape
-    cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
-    cond[:, ..., :-1] = latent_blur[:]
-    cond[:, ..., -1:] = 1.0
+    cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype)
+    #cond[:, ..., :-1] = latent_blur[:]
+    #cond[:, ..., -1:] = 1.0
     return cond
 
 def timestep_transform(timesteps, latents_shapes):
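After this change get_conditions keeps only the mask channel: the condition drops from (t, h, w, c + 1), blur latent plus flag, to (t, h, w, 1) of ones, and latent_blur is effectively unused. A shape sketch under that reading:

import torch

t, h, w, c = 4, 32, 32, 16
latent = torch.randn(t, h, w, c)

old_cond_shape = (t, h, w, c + 1)  # blur latent in [..., :-1], flag in [..., -1:]
new_cond = torch.ones(t, h, w, 1, dtype=latent.dtype)  # flag channel only
print(new_cond.shape)  # torch.Size([4, 32, 32, 1])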
@@ -144,6 +144,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
         vae = vae.to(device)
         images = images.to(device)
         latent = vae.encode(images)[0]
+        latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
+        latent = rearrange(latent, "b c ... -> b ... c")
 
         latent = (latent - shift) * scale
 
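The two added lines normalize the encoder output to a channels-last 5D layout before the shift/scale: a 4D image latent gains a singleton temporal axis, then channels move to the end. With an assumed image-VAE output shape:

import torch
from einops import rearrange

latent = torch.randn(1, 16, 64, 64)                           # (b, c, h, w)
latent = latent.unsqueeze(2) if latent.ndim == 4 else latent  # (b, c, 1, h, w)
latent = rearrange(latent, "b c ... -> b ... c")              # (b, 1, h, w, c)
print(latent.shape)  # torch.Size([1, 1, 64, 64, 16])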
@@ -156,9 +158,8 @@ class SeedVR2Conditioning(io.ComfyNode):
             node_id="SeedVR2Conditioning",
             category="image/video",
             inputs=[
-                io.Conditioning.Input("text_positive_conditioning"),
-                io.Conditioning.Input("text_negative_conditioning"),
-                io.Latent.Input("vae_conditioning")
+                io.Latent.Input("vae_conditioning"),
+                io.Model.Input("model"),
             ],
             outputs=[io.Conditioning.Output(display_name = "positive"),
                      io.Conditioning.Output(display_name = "negative"),
@@ -166,12 +167,13 @@ class SeedVR2Conditioning(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
+    def execute(cls, vae_conditioning, model) -> io.NodeOutput:
 
         vae_conditioning = vae_conditioning["samples"]
         device = vae_conditioning.device
-        pos_cond = text_positive_conditioning[0][0]
-        neg_cond = text_negative_conditioning[0][0]
+        model = model.model.diffusion_model
+        pos_cond = model.positive_conditioning
+        neg_cond = model.negative_conditioning
 
         noises = torch.randn_like(vae_conditioning).to(device)
         aug_noises = torch.randn_like(vae_conditioning).to(device)
@@ -181,21 +183,21 @@ class SeedVR2Conditioning(io.ComfyNode):
             torch.tensor([1000.0])
             * cond_noise_scale
         ).to(device)
-        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None]
+        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None] # avoid batch dim
         t = timestep_transform(t, shape)
         cond = inter(vae_conditioning, aug_noises, t)
-        condition = get_conditions(noises, cond)
+        condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
 
-        pos_shape = pos_cond.shape[1]
-        neg_shape = neg_cond.shape[1]
+        pos_shape = pos_cond.shape[0]
+        neg_shape = neg_cond.shape[0]
         diff = abs(pos_shape - neg_shape)
         if pos_shape > neg_shape:
             neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
 
-        negative = [[pos_cond, {"condition": condition}]]
-        positive = [[neg_cond, {"condition": condition}]]
+        negative = [[neg_cond, {"condition": condition}]]
+        positive = [[pos_cond, {"condition": condition}]]
 
         return io.NodeOutput(positive, negative, {"samples": noises})
 
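Two fixes land in this hunk: positive and negative had been built from each other's conditioning, and the length comparison now reads dim 0 because the conditioning tensors pulled off the model are unbatched (seq, dim). The F.pad call with (0, 0, 0, diff) pads the second-to-last axis, appending diff zero rows along the sequence dimension so both tensors match. A small sketch using the registered parameter shapes from the earlier hunk:

import torch
import torch.nn.functional as F

pos_cond = torch.randn(58, 5120)  # (seq, dim)
neg_cond = torch.randn(64, 5120)

diff = abs(pos_cond.shape[0] - neg_cond.shape[0])  # 6
if pos_cond.shape[0] > neg_cond.shape[0]:
    neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
else:
    pos_cond = F.pad(pos_cond, (0, 0, 0, diff))  # (58, 5120) -> (64, 5120)

assert pos_cond.shape == neg_cond.shape == (64, 5120)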