Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-29 15:50:22 +08:00)

Compare commits: 48 commits, 72f725c13b ... 5f6232b6e0
Commits (SHA1): 5f6232b6e0, c6238047ee, 5cd1113236, fffb96ad42, a506be2486, dbf8f9dcf9, 72ca18acc2, f588e6c821, 0da072e098, 31d358c78c, 4dd42ef1b7, 02529c6d57, 49febe15c3, 84fa155071, 4691717340, fadc7839cc, 3039c7ba14, 9b573da39b, 4d7012ecda, 21bc67d7db, 7b2e5ef0af, 1afc2ed8e6, d41b1111eb, 5b0c80a093, e30298dda2, 98b6bfcb71, fc5fabb629, 5db5da790f, a4e9d071e8, 4fe772fae9, 0d2044a778, 7e62f8cc9f, 74621b9d86, db74a27870, acb9a11c6f, d9f71da998, 183b377588, ebd945ce3d, 58e7cea796, 768c9cedf8, d629c8f910, 413ee3f687, d12702ee0b, f030b3afc8, 44a5bf353a, 4b9332cc21, 041dbd6a8a, 08d93555d0
@@ -183,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
 If you have trouble extracting it, right click the file -> properties -> unblock
 
-Update your Nvidia drivers if it doesn't start.
+The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
 
 #### Alternative Downloads:

@@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
 
-torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
+torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
 
 ### Instructions:
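The version guidance above is easy to verify locally. A minimal sketch (not part of this commit) that prints the installed torch build, the CUDA version it was compiled against, and whether a GPU is visible:

```python
import torch

# Show the torch build and the CUDA runtime it was compiled against.
print("torch:", torch.__version__)
print("cuda build:", torch.version.cuda)        # None on CPU-only wheels
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```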
@@ -747,6 +747,10 @@ class ACEAudio(LatentFormat):
     latent_channels = 8
     latent_dimensions = 2
 
+class SeedVR2(LatentFormat):
+    latent_channels = 16
+    latent_dimensions = 16
+
 class ChromaRadiance(LatentFormat):
     latent_channels = 3
@@ -19,9 +19,15 @@ if model_management.xformers_enabled():
     import xformers.ops
 
 SAGE_ATTENTION_IS_AVAILABLE = False
+SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = False
 try:
     from sageattention import sageattn
     SAGE_ATTENTION_IS_AVAILABLE = True
+    try:
+        from sageattention import sageattn_varlen
+        SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = True
+    except:
+        pass
 except ImportError as e:
     if model_management.sage_attention_enabled():
         if e.name == "sageattention":
@@ -39,7 +45,7 @@ except ImportError:
 
 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
-    from flash_attn import flash_attn_func
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
     FLASH_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.flash_attention_enabled():
@@ -87,7 +93,13 @@ def default(val, d):
         return val
     return d
 
+
+def var_attn_arg(kwargs):
+    cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
+    cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
+    max_seqlen_q = kwargs.get("max_seqlen_q", None)
+    max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
+    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
+    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
 # feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
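For context on var_attn_arg above: `cu_seqlens_q`/`cu_seqlens_k` are cumulative sequence lengths for sequences packed back to back along one token dimension, and `max_seqlen_*` is the longest individual sequence. An illustrative sketch (toy values, not code from this commit) of how a caller might build those kwargs:

```python
import torch

# Three sequences of 3, 5 and 2 tokens packed into one [10, ...] tensor.
seq_lens = torch.tensor([3, 5, 2])
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).to(torch.int32)
# cu_seqlens == tensor([0, 3, 8, 10]): the boundaries of each sequence.
max_seqlen = int(seq_lens.max())  # 5

kwargs = {
    "var_length": True,
    "cu_seqlens_q": cu_seqlens,
    "max_seqlen_q": max_seqlen,
}
```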
@@ -412,13 +424,14 @@ except:
 
 @wrap_attn
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
     disabled_xformers = False
 
     if BROKEN_XFORMERS:
-        if b * heads > 65535:
+        if b * heads > 65535 and not var_length:
             disabled_xformers = True
 
     if not disabled_xformers:
@@ -426,9 +439,27 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             disabled_xformers = True
 
     if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, var_length=var_length, **kwargs)
 
-    if skip_reshape:
+    if var_length:
+        if not skip_reshape:
+            total_tokens, hidden_dim = q.shape
+            dim_head = hidden_dim // heads
+            q = q.view(1, total_tokens, heads, dim_head)
+            k = k.view(1, total_tokens, heads, dim_head)
+            v = v.view(1, total_tokens, heads, dim_head)
+        else:
+            if q.ndim == 3:
+                q = q.unsqueeze(0)
+            if k.ndim == 3:
+                k = k.unsqueeze(0)
+            if v.ndim == 3:
+                v = v.unsqueeze(0)
+            dim_head = q.shape[-1]
+
+        target_output_shape = (q.shape[1], -1)
+        b = 1
+    elif skip_reshape:
         # b h k d -> b k h d
         q, k, v = map(
             lambda t: t.permute(0, 2, 1, 3),
@@ -442,7 +473,11 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             (q, k, v),
         )
 
-    if mask is not None:
+    if var_length:
+        cu_seqlens_q, _, _, _ = var_attn_arg(kwargs)
+        seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+        mask = xformers.ops.BlockDiagonalMask.from_seqlens(seq_lens_q=seq_lens, seq_lens_k=seq_lens)
+    elif mask is not None:
         # add a singleton batch dimension
         if mask.ndim == 2:
             mask = mask.unsqueeze(0)
@@ -464,6 +499,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
 
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
+    if var_length:
+        return out.reshape(*target_output_shape)
     if skip_output_reshape:
         out = out.permute(0, 2, 1, 3)
     else:
@@ -481,7 +518,28 @@ else:
 
 @wrap_attn
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    if skip_reshape:
+    var_length = kwargs.get("var_length", False)
+    if var_length:
+        cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
+        if not skip_reshape:
+            # assumes 2D q, k, v [total_tokens, embed_dim]
+            total_tokens, embed_dim = q.shape
+            head_dim = embed_dim // heads
+            q = q.view(total_tokens, heads, head_dim)
+            k = k.view(k.shape[0], heads, head_dim)
+            v = v.view(v.shape[0], heads, head_dim)
+
+        b = q.size(0)
+        dim_head = q.shape[-1]
+        q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
+        k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
+        v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
+
+        mask = None
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+    elif skip_reshape:
         b, _, _, dim_head = q.shape
     else:
         b, _, dim_head = q.shape
@@ -499,8 +557,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)
 
-    if SDP_BATCH_LIMIT >= b:
+    if SDP_BATCH_LIMIT >= b or var_length:
         out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        if var_length:
+            return out.contiguous().transpose(1, 2).values()
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)
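The var_length branch of attention_pytorch above packs the jagged batch into a nested (jagged) tensor so scaled_dot_product_attention attends within each sequence without padding. A standalone sketch of that idea (toy shapes; assumes a recent PyTorch with nested-tensor SDPA support, not code from this commit):

```python
import torch
import torch.nn.functional as F

heads, head_dim = 4, 8
seq_lens = torch.tensor([3, 5, 2])
total = int(seq_lens.sum())
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0), (1, 0))  # [0, 3, 8, 10]

# Packed projections: [total_tokens, heads, head_dim], as in the code above.
q, k, v = (torch.randn(total, heads, head_dim) for _ in range(3))

# Wrap as jagged nested tensors and move heads in front of the jagged dim.
nq, nk, nv = (
    torch.nested.nested_tensor_from_jagged(t, offsets=cu_seqlens.long()).transpose(1, 2)
    for t in (q, k, v)
)

# Attention is computed per sequence; tokens never attend across boundaries.
out = F.scaled_dot_product_attention(nq, nk, nv)
# attention_pytorch then unpacks this back to [total_tokens, heads, head_dim]
# via .contiguous().transpose(1, 2).values().
```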
@@ -524,8 +584,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 
 @wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     exception_fallback = False
-    if skip_reshape:
+    if var_length:
+        if not skip_reshape:
+            total_tokens, hidden_dim = q.shape
+            dim_head = hidden_dim // heads
+            q, k, v = [t.view(total_tokens, heads, dim_head) for t in (q, k, v)]
+        b, _, dim_head = q.shape
+        # skips batched code
+        mask = None
+        tensor_layout = "VAR"
+        target_output_shape = (q.shape[0], -1)
+    elif skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
     else:
@@ -546,7 +617,14 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             mask = mask.unsqueeze(1)
 
     try:
-        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+        if var_length and not SAGE_ATTENTION_VAR_LENGTH_AVAILABLE:
+            raise ValueError("Sage Attention two is required to run variable length attention.")
+        elif var_length:
+            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
+            sm_scale = 1.0 / (q.shape[-1] ** 0.5)
+            out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, is_causal=False, sm_scale=sm_scale)
+        else:
+            out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
         exception_fallback = True
@@ -556,7 +634,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             lambda t: t.transpose(1, 2),
             (q, k, v),
         )
-        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, var_length=var_length, **kwargs)
 
     if tensor_layout == "HND":
         if not skip_output_reshape:
@@ -567,6 +645,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         if skip_output_reshape:
             out = out.transpose(1, 2)
         else:
+            if var_length:
+                return out.view(*target_output_shape)
             out = out.reshape(b, -1, heads * dim_head)
     return out
@@ -678,6 +758,15 @@ except AttributeError as error:
 
 @wrap_attn
 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
+    if var_length:
+        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
+        return flash_attn_varlen_func(
+            q=q, k=k, v=v,
+            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
+            dropout_p=0.0, softmax_scale=None, causal=False
+        )
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -13,13 +13,14 @@ if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
 
 
 def torch_cat_if_needed(xl, dim):
     if len(xl) > 1:
         return torch.cat(xl, dim)
     else:
         return xl[0]
 
-def get_timestep_embedding(timesteps, embedding_dim):
+def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos = False, downscale_freq_shift = 1):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
     From Fairseq.

@@ -30,11 +31,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
     assert len(timesteps.shape) == 1
 
     half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
+    emb = math.log(10000) / (half_dim - downscale_freq_shift)
     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
     emb = emb.to(device=timesteps.device)
     emb = timesteps.float()[:, None] * emb[None, :]
     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
     if embedding_dim % 2 == 1: # zero pad
         emb = torch.nn.functional.pad(emb, (0,1,0,0))
     return emb
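For reference, the math of get_timestep_embedding with the new arguments, as a self-contained sketch (same formula as above, no ComfyUI imports; the flip simply swaps the sin and cos halves to match diffusers-style ordering):

```python
import math
import torch

def sinusoidal_embedding(timesteps, dim, flip_sin_to_cos=False, downscale_freq_shift=1):
    half = dim // 2
    # Geometric frequency ladder over half the embedding width.
    freqs = torch.exp(torch.arange(half, dtype=torch.float32)
                      * -(math.log(10000) / (half - downscale_freq_shift)))
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if flip_sin_to_cos:  # [cos | sin] instead of [sin | cos]
        emb = torch.cat([emb[:, half:], emb[:, :half]], dim=-1)
    if dim % 2 == 1:     # zero pad odd widths
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

emb = sinusoidal_embedding(torch.tensor([0, 10, 999]), 128)  # shape [3, 128]
```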
comfy/ldm/seedvr/model.py (new file, 1431 lines): diff suppressed because it is too large.
comfy/ldm/seedvr/vae.py (new file, 1936 lines): diff suppressed because it is too large.
@@ -47,6 +47,8 @@ import comfy.ldm.chroma.model
 import comfy.ldm.chroma_radiance.model
 import comfy.ldm.ace.model
 import comfy.ldm.omnigen.omnigen2
+import comfy.ldm.seedvr.model
 
 import comfy.ldm.qwen_image.model
 import comfy.ldm.kandinsky5.model
@@ -815,6 +817,16 @@ class HunyuanDiT(BaseModel):
         out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
         return out
 
+class SeedVR2(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        condition = kwargs.get("condition", None)
+        if condition is not None:
+            out["condition"] = comfy.conds.CONDRegular(condition)
+        return out
+
 class PixArt(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
@@ -449,6 +449,28 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
         return dit_config
 
+    elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
+        dit_config = {}
+        dit_config["image_model"] = "seedvr2"
+        dit_config["vid_dim"] = 3072
+        dit_config["heads"] = 24
+        dit_config["num_layers"] = 36
+        dit_config["norm_eps"] = 1e-5
+        dit_config["qk_rope"] = True
+        dit_config["mlp_type"] = "normal"
+        return dit_config
+    elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
+        dit_config = {}
+        dit_config["image_model"] = "seedvr2"
+        dit_config["vid_dim"] = 2560
+        dit_config["heads"] = 20
+        dit_config["num_layers"] = 32
+        dit_config["norm_eps"] = 1.0e-05
+        dit_config["qk_rope"] = None
+        dit_config["mlp_type"] = "swiglu"
+        dit_config["vid_out_norm"] = True
+        return dit_config
+
     if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
         dit_config = {}
         dit_config["image_model"] = "wan2.1"
comfy/samplers.py (0 lines changed; Executable file → Normal file)
comfy/sd.py (18 lines changed)
@@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
 import comfy.ldm.wan.vae
 import comfy.ldm.wan.vae2_2
 import comfy.ldm.hunyuan3d.vae
+import comfy.ldm.seedvr.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
 import comfy.ldm.hunyuan_video.vae
 import comfy.ldm.mmaudio.vae.autoencoder
@@ -312,7 +313,10 @@ class CLIP:
 class VAE:
     def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
-            sd = diffusers_convert.convert_vae_state_dict(sd)
+            if (metadata is not None and metadata["keep_diffusers_format"] == "true"):
+                pass
+            else:
+                sd = diffusers_convert.convert_vae_state_dict(sd)
 
         if model_management.is_amd():
             VAE_KL_MEM_RATIO = 2.73
@@ -379,6 +383,17 @@
             self.first_stage_model = StageC_coder()
             self.downscale_ratio = 32
             self.latent_channels = 16
+        elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
+            self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
+            self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.working_dtypes = [torch.bfloat16, torch.float32]
+            self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+            self.downscale_index_formula = (4, 8, 8)
+            self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+            self.upscale_index_formula = (4, 8, 8)
+            self.process_input = lambda image: image
+            self.crop_input = False
         elif "decoder.conv_in.weight" in sd:
             if sd['decoder.conv_in.weight'].shape[1] == 64:
                 ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
@@ -486,6 +501,7 @@
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
                 self.downscale_index_formula = (8, 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
 
         elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
             ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
             ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
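The seedvr2 downscale/upscale tuples registered above encode a 4x temporal and 8x spatial compression with the usual 4k+1 frame convention. A small sketch (illustrative only, not part of the commit) of how those formulas map pixel-space video dimensions to latent dimensions and back:

```python
import math

def seedvr2_latent_dims(frames, height, width):
    # Mirrors downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
    return (max(0, math.floor((frames + 3) / 4)), height // 8, width // 8)

def seedvr2_pixel_dims(latent_frames, latent_h, latent_w):
    # Mirrors upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
    return (max(0, latent_frames * 4 - 3), latent_h * 8, latent_w * 8)

print(seedvr2_latent_dims(17, 512, 768))  # (5, 64, 96)
print(seedvr2_pixel_dims(5, 64, 96))      # (17, 512, 768)
```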
@@ -1303,6 +1303,25 @@ class Chroma(supported_models_base.BASE):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
 
+class SeedVR2(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "seedvr2"
+    }
+    latent_format = comfy.latent_formats.SeedVR2
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    sampling_settings = {
+        "shift": 1.0,
+    }
+
+    def get_model(self, state_dict, prefix = "", device=None):
+        out = model_base.SeedVR2(self, device=device)
+        return out
+    def clip_target(self, state_dict={}):
+        return None
+
 class ChromaRadiance(Chroma):
     unet_config = {
         "image_model": "chroma_radiance",
@@ -1551,6 +1570,6 @@ class Kandinsky5Image(Kandinsky5):
         return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
 
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, SeedVR2]
 
 models += [SVD_img2vid]
@@ -567,7 +567,7 @@ async def execute_lipsync(
     # Upload the audio file to Comfy API and get download URL
     if audio:
         audio_url = await upload_audio_to_comfyapi(
-            cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3"
+            cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg"
        )
         logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
     else:
@@ -55,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
 
 def tensor_to_bytesio(
     image: torch.Tensor,
     name: str | None = None,
     *,
     total_pixels: int = 2048 * 2048,
     mime_type: str = "image/png",
 ) -> BytesIO:
@@ -75,7 +75,7 @@ def tensor_to_bytesio(
 
     pil_image = tensor_to_pil(image, total_pixels=total_pixels)
     img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
-    img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
+    img_binary.name = f"{uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
     return img_binary
@@ -82,7 +82,6 @@ async def upload_audio_to_comfyapi(
     container_format: str = "mp4",
     codec_name: str = "aac",
     mime_type: str = "audio/mp4",
-    filename: str = "uploaded_audio.mp4",
 ) -> str:
     """
     Uploads a single audio input to ComfyUI API and returns its download URL.
@@ -92,7 +91,7 @@ async def upload_audio_to_comfyapi(
     waveform: torch.Tensor = audio["waveform"]
     audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
     audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
-    return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type)
+    return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type)
 
 
 async def upload_video_to_comfyapi(
comfy_extras/nodes_seedvr.py (new file, 465 lines)
@@ -0,0 +1,465 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import torch
import math
from einops import rearrange

import gc
import comfy.model_management
from comfy.utils import ProgressBar

import torch.nn.functional as F
from torchvision.transforms import functional as TVF
from torchvision.transforms import Lambda, Normalize
from torchvision.transforms.functional import InterpolationMode


@torch.inference_mode()
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True):

    gc.collect()
    torch.cuda.empty_cache()

    x = x.to(next(vae_model.parameters()).dtype)
    if x.ndim != 5:
        x = x.unsqueeze(2)

    b, c, d, h, w = x.shape

    sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
    sf_t = getattr(vae_model, "temporal_downsample_factor", 4)

    if encode:
        ti_h, ti_w = tile_size
        ov_h, ov_w = tile_overlap
        target_d = (d + sf_t - 1) // sf_t
        target_h = (h + sf_s - 1) // sf_s
        target_w = (w + sf_s - 1) // sf_s
    else:
        ti_h = max(1, tile_size[0] // sf_s)
        ti_w = max(1, tile_size[1] // sf_s)
        ov_h = max(0, tile_overlap[0] // sf_s)
        ov_w = max(0, tile_overlap[1] // sf_s)

        target_d = d * sf_t
        target_h = h * sf_s
        target_w = w * sf_s

    stride_h = max(1, ti_h - ov_h)
    stride_w = max(1, ti_w - ov_w)

    storage_device = vae_model.device
    result = None
    count = None

    def run_temporal_chunks(spatial_tile):
        chunk_results = []
        t_dim_size = spatial_tile.shape[2]

        if encode:
            input_chunk = temporal_size
        else:
            input_chunk = max(1, temporal_size // sf_t)
        for i in range(0, t_dim_size, input_chunk):
            t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
            current_valid_len = t_chunk.shape[2]

            pad_amount = 0
            if current_valid_len < input_chunk:
                pad_amount = input_chunk - current_valid_len

                last_frame = t_chunk[:, :, -1:, :, :]
                padding = last_frame.repeat(1, 1, pad_amount, 1, 1)

                t_chunk = torch.cat([t_chunk, padding], dim=2)
            t_chunk = t_chunk.contiguous()

            if encode:
                out = vae_model.encode(t_chunk)[0]
            else:
                out = vae_model.decode_(t_chunk)

            if isinstance(out, (tuple, list)):
                out = out[0]
            if out.ndim == 4:
                out = out.unsqueeze(2)

            if pad_amount > 0:
                if encode:
                    expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
                    out = out[:, :, :expected_valid_out, :, :]
                else:
                    expected_valid_out = current_valid_len * sf_t
                    out = out[:, :, :expected_valid_out, :, :]

            chunk_results.append(out.to(storage_device))

        return torch.cat(chunk_results, dim=2)

    ramp_cache = {}
    def get_ramp(steps):
        if steps not in ramp_cache:
            t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32)
            ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
        return ramp_cache[steps]

    total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w))
    bar = ProgressBar(total_tiles)

    for y_idx in range(0, h, stride_h):
        y_end = min(y_idx + ti_h, h)

        for x_idx in range(0, w, stride_w):
            x_end = min(x_idx + ti_w, w)

            tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]

            # Run VAE
            tile_out = run_temporal_chunks(tile_x)

            if result is None:
                b_out, c_out = tile_out.shape[0], tile_out.shape[1]
                result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
                count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32)

            if encode:
                ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
                xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
                cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
                cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
            else:
                ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
                xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
                cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
                cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))

            w_h = torch.ones((tile_out.shape[3],), device=storage_device)
            w_w = torch.ones((tile_out.shape[4],), device=storage_device)

            if cur_ov_h > 0:
                r = get_ramp(cur_ov_h)
                if y_idx > 0:
                    w_h[:cur_ov_h] = r
                if y_end < h:
                    w_h[-cur_ov_h:] = 1.0 - r

            if cur_ov_w > 0:
                r = get_ramp(cur_ov_w)
                if x_idx > 0:
                    w_w[:cur_ov_w] = r
                if x_end < w:
                    w_w[-cur_ov_w:] = 1.0 - r

            final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)

            valid_d = min(tile_out.shape[2], result.shape[2])
            tile_out = tile_out[:, :, :valid_d, :, :]

            tile_out.mul_(final_weight)

            result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
            count[:, :, :, ys:ye, xs:xe] += final_weight

            del tile_out, final_weight, tile_x, w_h, w_w
            bar.update(1)

    result.div_(count.clamp(min=1e-6))

    if result.device != x.device:
        result = result.to(x.device).to(x.dtype)

    if x.shape[2] == 1 and sf_t == 1:
        result = result.squeeze(2)

    return result
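An aside on the blending above (illustration only, not part of nodes_seedvr.py): tiled_vae weights overlapping tiles with a raised-cosine ramp so seams fade smoothly instead of hard-cutting. A tiny sketch of the blend weights for an 8-pixel overlap:

```python
import torch

overlap = 8
t = torch.linspace(0, 1, steps=overlap)
ramp = 0.5 - 0.5 * torch.cos(t * torch.pi)   # same formula as get_ramp() above

left_tile_weight = 1.0 - ramp   # fades out toward the seam
right_tile_weight = ramp        # fades in from the seam
# The two weights sum to 1 across the overlap, so result.div_(count) stays well conditioned.
print(left_tile_weight + right_tile_weight)  # all ones
```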
def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False):
    t = videos.size(temporal_dim)

    if count == 0 and not prepend:
        if t % 4 == 1:
            return videos
        count = ((t - 1) // 4 + 1) * 4 + 1 - t

    if count <= 0:
        return videos

    def select(start, end):
        return videos[start:end] if temporal_dim == 0 else videos[:, start:end]

    if count >= t:
        repeat_count = count - t + 1
        last = select(-1, None)

        if temporal_dim == 0:
            repeated = last.repeat(repeat_count, 1, 1, 1)
            reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:0]
        else:
            repeated = last.expand(-1, repeat_count, -1, -1).contiguous()
            reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:, :0]

        return torch.cat([repeated, reversed_frames, videos] if prepend else
                         [videos, reversed_frames, repeated], dim=temporal_dim)

    if prepend:
        reversed_frames = select(1, count+1).flip(temporal_dim)
    else:
        reversed_frames = select(-count-1, -1).flip(temporal_dim)

    return torch.cat([reversed_frames, videos] if prepend else
                     [videos, reversed_frames], dim=temporal_dim)


def clear_vae_memory(vae_model):
    for module in vae_model.modules():
        if hasattr(module, "memory"):
            module.memory = None
    gc.collect()
    torch.cuda.empty_cache()


def expand_dims(tensor, ndim):
    shape = tensor.shape + (1,) * (ndim - tensor.ndim)
    return tensor.reshape(shape)


def get_conditions(latent, latent_blur):
    t, h, w, c = latent.shape
    cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
    cond[:, ..., :-1] = latent_blur[:]
    cond[:, ..., -1:] = 1.0
    return cond


def timestep_transform(timesteps, latents_shapes):
    vt = 4
    vs = 8
    frames = (latents_shapes[:, 0] - 1) * vt + 1
    heights = latents_shapes[:, 1] * vs
    widths = latents_shapes[:, 2] * vs

    # Compute shift factor.
    def get_lin_function(x1, y1, x2, y2):
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return lambda x: m * x + b

    img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
    vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
    shift = torch.where(
        frames > 1,
        vid_shift_fn(heights * widths * frames),
        img_shift_fn(heights * widths),
    ).to(timesteps.device)

    # Shift timesteps.
    T = 1000.0
    timesteps = timesteps / T
    timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
    timesteps = timesteps * T
    return timesteps


def inter(x_0, x_T, t):
    t = expand_dims(t, x_0.ndim)
    T = 1000.0
    B = lambda t: t / T
    A = lambda t: 1 - (t / T)
    return A(t) * x_0 + B(t) * x_T


def area_resize(image, max_area):
    height, width = image.shape[-2:]
    scale = math.sqrt(max_area / (height * width))

    resized_height, resized_width = round(height * scale), round(width * scale)

    return TVF.resize(
        image,
        size=(resized_height, resized_width),
        interpolation=InterpolationMode.BICUBIC,
    )


def div_pad(image, factor):
    height_factor, width_factor = factor
    height, width = image.shape[-2:]

    pad_height = (height_factor - (height % height_factor)) % height_factor
    pad_width = (width_factor - (width % width_factor)) % width_factor

    if pad_height == 0 and pad_width == 0:
        return image

    if isinstance(image, torch.Tensor):
        padding = (0, pad_width, 0, pad_height)
        image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0)

    return image


def cut_videos(videos):
    t = videos.size(1)
    if t == 1:
        return videos
    if t <= 4:
        padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1)
        padding = torch.cat(padding, dim=1)
        videos = torch.cat([videos, padding], dim=1)
        return videos
    if (t - 1) % (4) == 0:
        return videos
    else:
        padding = [videos[:, -1].unsqueeze(1)] * (
            4 - ((t - 1) % (4))
        )
        padding = torch.cat(padding, dim=1)
        videos = torch.cat([videos, padding], dim=1)
        assert (videos.size(1) - 1) % (4) == 0
        return videos


def side_resize(image, size):
    antialias = not (isinstance(image, torch.Tensor) and image.device.type == 'mps')
    resized = TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=antialias)
    return resized


class SeedVR2InputProcessing(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id = "SeedVR2InputProcessing",
            category="image/video",
            inputs = [
                io.Image.Input("images"),
                io.Vae.Input("vae"),
                io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
                io.Int.Input("spatial_tile_size", default = 512, min = 1),
                io.Int.Input("spatial_overlap", default = 64, min = 1),
                io.Int.Input("temporal_tile_size", default=5, min=1, max=16384, step=4),
                io.Boolean.Input("enable_tiling", default=False),
            ],
            outputs = [
                io.Latent.Output("vae_conditioning")
            ]
        )

    @classmethod
    def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):

        comfy.model_management.load_models_gpu([vae.patcher])
        vae_model = vae.first_stage_model
        scale = 0.9152
        shift = 0
        if images.dim() != 5: # add the t dim
            images = images.unsqueeze(0)
        images = images.permute(0, 1, 4, 2, 3)

        b, t, c, h, w = images.shape
        images = images.reshape(b * t, c, h, w)

        clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
        normalize = Normalize(0.5, 0.5)
        images = side_resize(images, resolution)

        images = clip(images)
        o_h, o_w = images.shape[-2:]
        images = div_pad(images, (16, 16))
        images = normalize(images)
        _, _, new_h, new_w = images.shape

        images = images.reshape(b, t, c, new_h, new_w)
        images = cut_videos(images)

        images = rearrange(images, "b t c h w -> b c t h w")

        # in case the user picks a non-compatible number for tiling
        def make_divisible(val, divisor):
            return max(divisor, round(val / divisor) * divisor)

        spatial_tile_size = make_divisible(spatial_tile_size, 32)
        spatial_overlap = make_divisible(spatial_overlap, 32)

        if spatial_overlap >= spatial_tile_size:
            spatial_overlap = max(0, spatial_tile_size - 8)

        args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
                "temporal_size": temporal_tile_size}
        if enable_tiling:
            latent = tiled_vae(images, vae_model, encode=True, **args)
        else:
            latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0]

        clear_vae_memory(vae_model)
        #images = images.to(offload_device)
        #vae_model = vae_model.to(offload_device)

        vae_model.img_dims = [o_h, o_w]
        args["enable_tiling"] = enable_tiling
        vae_model.tiled_args = args
        vae_model.original_image_video = images

        latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
        latent = rearrange(latent, "b c ... -> b ... c")

        latent = (latent - shift) * scale

        return io.NodeOutput({"samples": latent})


class SeedVR2Conditioning(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SeedVR2Conditioning",
            category="image/video",
            inputs=[
                io.Latent.Input("vae_conditioning"),
                io.Model.Input("model"),
            ],
            outputs=[io.Conditioning.Output(display_name = "positive"),
                     io.Conditioning.Output(display_name = "negative"),
                     io.Latent.Output(display_name = "latent")],
        )

    @classmethod
    def execute(cls, vae_conditioning, model) -> io.NodeOutput:

        vae_conditioning = vae_conditioning["samples"]
        device = vae_conditioning.device
        model = model.model.diffusion_model
        pos_cond = model.positive_conditioning
        neg_cond = model.negative_conditioning

        noises = torch.randn_like(vae_conditioning).to(device)
        aug_noises = torch.randn_like(vae_conditioning).to(device)
        aug_noises = noises * 0.1 + aug_noises * 0.05
        cond_noise_scale = 0.0
        t = (
            torch.tensor([1000.0])
            * cond_noise_scale
        ).to(device)
        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None] # avoid batch dim
        t = timestep_transform(t, shape)
        cond = inter(vae_conditioning, aug_noises, t)
        condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
        condition = condition.movedim(-1, 1)
        noises = noises.movedim(-1, 1)

        pos_shape = pos_cond.shape[0]
        neg_shape = neg_cond.shape[0]
        diff = abs(pos_shape - neg_shape)
        if pos_shape > neg_shape:
            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
        else:
            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

        noises = rearrange(noises, "b c t h w -> b (c t) h w")
        condition = rearrange(condition, "b c t h w -> b (c t) h w")

        negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
        positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]

        return io.NodeOutput(positive, negative, {"samples": noises})


class SeedVRExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            SeedVR2Conditioning,
            SeedVR2InputProcessing
        ]


async def comfy_entrypoint() -> SeedVRExtension:
    return SeedVRExtension()
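Both cut_videos and pad_video_temporal above pad clips so the frame count satisfies the 4k+1 pattern the SeedVR2 VAE expects, matching its 4x temporal compression. A quick standalone illustration of the padded lengths (not part of the node file):

```python
def padded_frame_count(t: int) -> int:
    # Same arithmetic as pad_video_temporal: round t up to the next 4k + 1.
    if t % 4 == 1:
        return t
    return ((t - 1) // 4 + 1) * 4 + 1

for t in (1, 2, 5, 6, 16, 17):
    print(t, "->", padded_frame_count(t))
# 1 -> 1, 2 -> 5, 5 -> 5, 6 -> 9, 16 -> 17, 17 -> 17
```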