Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-31 00:30:21 +08:00)

Compare commits: 52 commits, d065cdf286 ... 6c65b6b4bf

Commit SHA1s: 6c65b6b4bf, 1a72bf2046, 371c319cf9, 5f625fcc78, c78bfda132, bc4fd2cd11, 3c149dd543, fffb96ad42, a506be2486, dbf8f9dcf9, 72ca18acc2, f588e6c821, 0da072e098, 31d358c78c, 4dd42ef1b7, 02529c6d57, 49febe15c3, 84fa155071, 4691717340, fadc7839cc, 3039c7ba14, 9b573da39b, 4d7012ecda, 21bc67d7db, 7b2e5ef0af, 1afc2ed8e6, d41b1111eb, 5b0c80a093, e30298dda2, 98b6bfcb71, fc5fabb629, 5db5da790f, a4e9d071e8, 4fe772fae9, 0d2044a778, 7e62f8cc9f, 74621b9d86, db74a27870, acb9a11c6f, d9f71da998, 183b377588, ebd945ce3d, 58e7cea796, 768c9cedf8, d629c8f910, 413ee3f687, d12702ee0b, f030b3afc8, 44a5bf353a, 4b9332cc21, 041dbd6a8a, 08d93555d0
@@ -108,7 +108,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Works fully offline: core will never download anything unless you want to.
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview), disable with: `--disable-api-nodes`
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.

Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The

Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12

torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.

### Instructions:

@@ -229,7 +229,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins

```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```

This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:

```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
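As a quick sanity check after installing either the stable ROCm 6.4 wheel or the ROCm 7.1 nightly, the active torch build and its HIP version can be printed from Python. This is a standalone illustration (not part of the patch); it only uses attributes that exist in every recent torch build.

```python
# Verify which torch build is active: torch.version.hip is None on CUDA/CPU
# builds and a ROCm/HIP version string on ROCm builds.
import torch

print(torch.__version__)          # e.g. a 2.x release or a nightly ".dev" build
print(torch.version.hip)          # ROCm/HIP version string, or None
print(torch.cuda.is_available())  # True when the GPU is visible to torch
```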
@@ -747,6 +747,10 @@ class ACEAudio(LatentFormat):
    latent_channels = 8
    latent_dimensions = 2

class SeedVR2(LatentFormat):
    latent_channels = 16
    latent_dimensions = 16

class ChromaRadiance(LatentFormat):
    latent_channels = 3

@@ -19,9 +19,15 @@ if model_management.xformers_enabled():
    import xformers.ops

SAGE_ATTENTION_IS_AVAILABLE = False
SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = False
try:
    from sageattention import sageattn
    SAGE_ATTENTION_IS_AVAILABLE = True
    try:
        from sageattention import sageattn_varlen
        SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = True
    except:
        pass
except ImportError as e:
    if model_management.sage_attention_enabled():
        if e.name == "sageattention":
@@ -39,7 +45,7 @@ except ImportError:

FLASH_ATTENTION_IS_AVAILABLE = False
try:
    from flash_attn import flash_attn_func
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    FLASH_ATTENTION_IS_AVAILABLE = True
except ImportError:
    if model_management.flash_attention_enabled():
@@ -87,7 +93,13 @@ def default(val, d):
        return val
    return d


def var_attn_arg(kwargs):
    cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
    cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
    max_seqlen_q = kwargs.get("max_seqlen_q", None)
    max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
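The helper above reads the packed-sequence metadata out of `**kwargs`. As a standalone illustration with made-up lengths (only `var_attn_arg` itself comes from the patch), `cu_seqlens_*` are cumulative boundaries over the packed token dimension and `max_seqlen_*` is the longest individual sequence:

```python
# Toy packing example: two sequences of lengths 3 and 5 share one packed tensor.
import torch
import torch.nn.functional as F

seq_lens = torch.tensor([3, 5])
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).int()  # tensor([0, 3, 8])

kwargs = {
    "cu_seqlens_q": cu_seqlens,            # boundaries of each packed sequence
    "max_seqlen_q": int(seq_lens.max()),   # longest sequence in the batch
}
# var_attn_arg(kwargs) would return (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k),
# with the k-side values defaulting to the q-side ones when they are not given.
```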
@@ -412,13 +424,14 @@ except:

@wrap_attn
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    var_length = kwargs.get("var_length", False)
    b = q.shape[0]
    dim_head = q.shape[-1]
    # check to make sure xformers isn't broken
    disabled_xformers = False

    if BROKEN_XFORMERS:
        if b * heads > 65535:
        if b * heads > 65535 and not var_length:
            disabled_xformers = True

    if not disabled_xformers:
@@ -426,9 +439,27 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
            disabled_xformers = True

    if disabled_xformers:
        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, var_length=var_length, **kwargs)

    if skip_reshape:
    if var_length:
        if not skip_reshape:
            total_tokens, hidden_dim = q.shape
            dim_head = hidden_dim // heads
            q = q.view(1, total_tokens, heads, dim_head)
            k = k.view(1, total_tokens, heads, dim_head)
            v = v.view(1, total_tokens, heads, dim_head)
        else:
            if q.ndim == 3:
                q = q.unsqueeze(0)
            if k.ndim == 3:
                k = k.unsqueeze(0)
            if v.ndim == 3:
                v = v.unsqueeze(0)
            dim_head = q.shape[-1]

        target_output_shape = (q.shape[1], -1)
        b = 1
    elif skip_reshape:
        # b h k d -> b k h d
        q, k, v = map(
            lambda t: t.permute(0, 2, 1, 3),
@@ -442,7 +473,11 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
            (q, k, v),
        )

    if mask is not None:
    if var_length:
        cu_seqlens_q, _, _, _ = var_attn_arg(kwargs)
        seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
        mask = xformers.ops.BlockDiagonalMask.from_seqlens(seq_lens_q=seq_lens, seq_lens_k=seq_lens)
    elif mask is not None:
        # add a singleton batch dimension
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
@@ -464,6 +499,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    if var_length:
        return out.reshape(*target_output_shape)
    if skip_output_reshape:
        out = out.permute(0, 2, 1, 3)
    else:
@@ -481,7 +518,28 @@ else:

@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    if skip_reshape:
    var_length = kwargs.get("var_length", False)
    if var_length:
        cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
        if not skip_reshape:
            # assumes 2D q, k,v [total_tokens, embed_dim]
            total_tokens, embed_dim = q.shape
            head_dim = embed_dim // heads
            q = q.view(total_tokens, heads, head_dim)
            k = k.view(k.shape[0], heads, head_dim)
            v = v.view(v.shape[0], heads, head_dim)

        b = q.size(0)
        dim_head = q.shape[-1]
        q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
        k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
        v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())

        mask = None
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
    elif skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
@@ -499,8 +557,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)

    if SDP_BATCH_LIMIT >= b:
    if SDP_BATCH_LIMIT >= b or var_length:
        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
        if var_length:
            return out.contiguous().transpose(1, 2).values()
        if not skip_output_reshape:
            out = (
                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -524,8 +584,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha

@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    var_length = kwargs.get("var_length", False)
    exception_fallback = False
    if skip_reshape:
    if var_length:
        if not skip_reshape:
            total_tokens, hidden_dim = q.shape
            dim_head = hidden_dim // heads
            q, k, v = [t.view(total_tokens, heads, dim_head) for t in (q, k, v)]
        b, _, dim_head = q.shape
        # skips batched code
        mask = None
        tensor_layout = "VAR"
        target_output_shape = (q.shape[0], -1)
    elif skip_reshape:
        b, _, _, dim_head = q.shape
        tensor_layout = "HND"
    else:
@@ -546,7 +617,14 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
            mask = mask.unsqueeze(1)

    try:
        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
        if var_length and not SAGE_ATTENTION_VAR_LENGTH_AVAILABLE:
            raise ValueError("Sage Attention two is required to run variable length attention.")
        elif var_length:
            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
            sm_scale = 1.0 / (q.shape[-1] ** 0.5)
            out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, is_causal=False, sm_scale=sm_scale)
        else:
            out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
    except Exception as e:
        logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
        exception_fallback = True
@@ -556,7 +634,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
            lambda t: t.transpose(1, 2),
            (q, k, v),
        )
        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, var_length=var_length, **kwargs)

    if tensor_layout == "HND":
        if not skip_output_reshape:
@@ -567,6 +645,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
        if skip_output_reshape:
            out = out.transpose(1, 2)
        else:
            if var_length:
                return out.view(*target_output_shape)
            out = out.reshape(b, -1, heads * dim_head)
    return out

@@ -678,6 +758,15 @@ except AttributeError as error:

@wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    var_length = kwargs.get("var_length", False)
    if var_length:
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
        return flash_attn_varlen_func(
            q=q, k=k, v=v,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
            dropout_p=0.0, softmax_scale=None, causal=False
        )
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:

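The common idea behind all of the var_length branches above is that attention is computed independently inside each packed segment, never across segment boundaries. The snippet below is a plain-PyTorch reference with toy sizes (it does not call any ComfyUI code): a per-segment loop that produces the same packed output layout the variable-length kernels return.

```python
# Reference semantics of variable-length attention: run SDPA per segment and
# re-pack the results along the token dimension.
import torch
import torch.nn.functional as F

heads, dim_head = 2, 8
seq_lens = [3, 5]
total = sum(seq_lens)
q = torch.randn(total, heads, dim_head)
k = torch.randn(total, heads, dim_head)
v = torch.randn(total, heads, dim_head)

outs = []
start = 0
for n in seq_lens:
    qs, ks, vs = (t[start:start + n].transpose(0, 1) for t in (q, k, v))       # [heads, n, d]
    outs.append(F.scaled_dot_product_attention(qs, ks, vs).transpose(0, 1))    # [n, heads, d]
    start += n

out = torch.cat(outs, dim=0)  # packed back to [total_tokens, heads, dim_head]
print(out.shape)              # torch.Size([8, 2, 8])
```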
@@ -13,13 +13,14 @@ if model_management.xformers_enabled_vae():
    import xformers
    import xformers.ops


def torch_cat_if_needed(xl, dim):
    if len(xl) > 1:
        return torch.cat(xl, dim)
    else:
        return xl[0]

def get_timestep_embedding(timesteps, embedding_dim):
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos = False, downscale_freq_shift = 1):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.

@@ -30,11 +31,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = math.log(10000) / (half_dim - downscale_freq_shift)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    if embedding_dim % 2 == 1: # zero pad
        emb = torch.nn.functional.pad(emb, (0,1,0,0))
    return emb

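For a quick numerical check of the two new parameters, the snippet below re-implements the embedding exactly as shown in the hunk (standalone, without importing ComfyUI): `flip_sin_to_cos` only swaps the sine and cosine halves, and `downscale_freq_shift` changes the denominator of the frequency spacing.

```python
# Sinusoidal timestep embedding, mirrored from the diff above for testing.
import math
import torch

def timestep_embedding(timesteps, dim, flip_sin_to_cos=False, downscale_freq_shift=1):
    half = dim // 2
    freqs = torch.exp(torch.arange(half, dtype=torch.float32) * -(math.log(10000) / (half - downscale_freq_shift)))
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half:], emb[:, :half]], dim=-1)
    return emb

t = torch.tensor([0, 500, 999])
a = timestep_embedding(t, 8)
b = timestep_embedding(t, 8, flip_sin_to_cos=True)
print(torch.allclose(a[:, :4], b[:, 4:]))  # True: the halves are simply swapped
```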
comfy/ldm/seedvr/model.py (new file, 1482 lines): file diff suppressed because it is too large
comfy/ldm/seedvr/vae.py (new file, 2166 lines): file diff suppressed because it is too large
@@ -47,6 +47,8 @@ import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.seedvr.model

import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model

@@ -815,6 +817,16 @@ class HunyuanDiT(BaseModel):
        out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
        return out

class SeedVR2(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        condition = kwargs.get("condition", None)
        if condition is not None:
            out["condition"] = comfy.conds.CONDRegular(condition)
        return out

class PixArt(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)

@@ -449,6 +449,28 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

        return dit_config

    elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
        dit_config = {}
        dit_config["image_model"] = "seedvr2"
        dit_config["vid_dim"] = 3072
        dit_config["heads"] = 24
        dit_config["num_layers"] = 36
        dit_config["norm_eps"] = 1e-5
        dit_config["qk_rope"] = True
        dit_config["mlp_type"] = "normal"
        return dit_config
    elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
        dit_config = {}
        dit_config["image_model"] = "seedvr2"
        dit_config["vid_dim"] = 2560
        dit_config["heads"] = 20
        dit_config["num_layers"] = 32
        dit_config["norm_eps"] = 1.0e-05
        dit_config["qk_rope"] = None
        dit_config["mlp_type"] = "swiglu"
        dit_config["vid_out_norm"] = True
        return dit_config

    if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
        dit_config = {}
        dit_config["image_model"] = "wan2.1"

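The detection above keys off state-dict entries that only one SeedVR2 variant contains. A toy illustration of the same pattern (the key names come from the hunk; the state dict and prefix below are made up):

```python
# Architecture sniffing: pick a config by probing for a variant-specific key.
state_dict_keys = {"model.diffusion_model.blocks.31.mlp.all.proj_in_gate.weight"}
key_prefix = "model.diffusion_model."

if "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys:
    print("seedvr2 7b")
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys:
    print("seedvr2 3b")  # this branch fires for the toy state dict above
```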
comfy/samplers.py: 0 lines changed (Executable file → Normal file)
comfy/sd.py: 18 lines changed
@@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.seedvr.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder

@@ -312,7 +313,10 @@ class CLIP:
class VAE:
    def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
            sd = diffusers_convert.convert_vae_state_dict(sd)
            if (metadata is not None and metadata["keep_diffusers_format"] == "true"):
                pass
            else:
                sd = diffusers_convert.convert_vae_state_dict(sd)

        if model_management.is_amd():
            VAE_KL_MEM_RATIO = 2.73

@@ -379,6 +383,17 @@ class VAE:
                self.first_stage_model = StageC_coder()
                self.downscale_ratio = 32
                self.latent_channels = 16
            elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
                self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
                self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
                self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
                self.working_dtypes = [torch.bfloat16, torch.float32]
                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
                self.downscale_index_formula = (4, 8, 8)
                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
                self.upscale_index_formula = (4, 8, 8)
                self.process_input = lambda image: image
                self.crop_input = False
            elif "decoder.conv_in.weight" in sd:
                if sd['decoder.conv_in.weight'].shape[1] == 64:
                    ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}

@@ -486,6 +501,7 @@ class VAE:
                self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
                self.downscale_index_formula = (8, 32, 32)
                self.working_dtypes = [torch.bfloat16, torch.float32]

            elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
                ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
                ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]

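The seedvr2 branch above registers a (lambda, 8, 8) downscale ratio and a matching upscale ratio, i.e. 8x spatial and roughly 4x temporal compression with the first frame handled specially. A small standalone check of those two lambdas (copied from the hunk, run on a few example frame counts):

```python
# Worked example of the seedvr2 temporal compression mapping.
import math

temporal_down = lambda a: max(0, math.floor((a + 3) / 4))  # pixels -> latent frames
temporal_up = lambda a: max(0, a * 4 - 3)                  # latent frames -> pixels

for frames in (1, 5, 17):
    latent_frames = temporal_down(frames)
    print(frames, "->", latent_frames, "->", temporal_up(latent_frames))
# 1 -> 1 -> 1, 5 -> 2 -> 5, 17 -> 5 -> 17; height and width just divide by 8
```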
@@ -1303,6 +1303,25 @@ class Chroma(supported_models_base.BASE):
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))

class SeedVR2(supported_models_base.BASE):
    unet_config = {
        "image_model": "seedvr2"
    }
    latent_format = comfy.latent_formats.SeedVR2

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    sampling_settings = {
        "shift": 1.0,
    }

    def get_model(self, state_dict, prefix = "", device=None):
        out = model_base.SeedVR2(self, device=device)
        return out
    def clip_target(self, state_dict={}):
        return None

class ChromaRadiance(Chroma):
    unet_config = {
        "image_model": "chroma_radiance",

@@ -1551,6 +1570,6 @@ class Kandinsky5Image(Kandinsky5):
        return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))


models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, SeedVR2]

models += [SVD_img2vid]

comfy_extras/nodes_seedvr.py (new file, 466 lines)
@@ -0,0 +1,466 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import torch
import math
from einops import rearrange

import gc
import comfy.model_management
from comfy.utils import ProgressBar

import torch.nn.functional as F
from torchvision.transforms import functional as TVF
from torchvision.transforms import Lambda, Normalize
from torchvision.transforms.functional import InterpolationMode

@torch.inference_mode()
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True):

    gc.collect()
    torch.cuda.empty_cache()

    x = x.to(next(vae_model.parameters()).dtype)
    if x.ndim != 5:
        x = x.unsqueeze(2)

    b, c, d, h, w = x.shape

    sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
    sf_t = getattr(vae_model, "temporal_downsample_factor", 4)

    if encode:
        ti_h, ti_w = tile_size
        ov_h, ov_w = tile_overlap
        target_d = (d + sf_t - 1) // sf_t
        target_h = (h + sf_s - 1) // sf_s
        target_w = (w + sf_s - 1) // sf_s
    else:
        ti_h = max(1, tile_size[0] // sf_s)
        ti_w = max(1, tile_size[1] // sf_s)
        ov_h = max(0, tile_overlap[0] // sf_s)
        ov_w = max(0, tile_overlap[1] // sf_s)

        target_d = d * sf_t
        target_h = h * sf_s
        target_w = w * sf_s

    stride_h = max(1, ti_h - ov_h)
    stride_w = max(1, ti_w - ov_w)

    storage_device = vae_model.device
    result = None
    count = None

    def run_temporal_chunks(spatial_tile):
        chunk_results = []
        t_dim_size = spatial_tile.shape[2]

        if encode:
            input_chunk = temporal_size
        else:
            input_chunk = max(1, temporal_size // sf_t)
        for i in range(0, t_dim_size, input_chunk):
            t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
            current_valid_len = t_chunk.shape[2]

            pad_amount = 0
            if current_valid_len < input_chunk:
                pad_amount = input_chunk - current_valid_len

                last_frame = t_chunk[:, :, -1:, :, :]
                padding = last_frame.repeat(1, 1, pad_amount, 1, 1)

                t_chunk = torch.cat([t_chunk, padding], dim=2)
            t_chunk = t_chunk.contiguous()

            if encode:
                out = vae_model.encode(t_chunk)[0]
            else:
                out = vae_model.decode_(t_chunk)

            if isinstance(out, (tuple, list)):
                out = out[0]
            if out.ndim == 4:
                out = out.unsqueeze(2)

            if pad_amount > 0:
                if encode:
                    expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
                    out = out[:, :, :expected_valid_out, :, :]

                else:
                    expected_valid_out = current_valid_len * sf_t
                    out = out[:, :, :expected_valid_out, :, :]

            chunk_results.append(out.to(storage_device))

        return torch.cat(chunk_results, dim=2)

    ramp_cache = {}
    def get_ramp(steps):
        if steps not in ramp_cache:
            t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32)
            ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
        return ramp_cache[steps]

    total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w))
    bar = ProgressBar(total_tiles)

    for y_idx in range(0, h, stride_h):
        y_end = min(y_idx + ti_h, h)

        for x_idx in range(0, w, stride_w):
            x_end = min(x_idx + ti_w, w)

            tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]

            # Run VAE
            tile_out = run_temporal_chunks(tile_x)

            if result is None:
                b_out, c_out = tile_out.shape[0], tile_out.shape[1]
                result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
                count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32)

            if encode:
                ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
                xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
                cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
                cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
            else:
                ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
                xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
                cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
                cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))

            w_h = torch.ones((tile_out.shape[3],), device=storage_device)
            w_w = torch.ones((tile_out.shape[4],), device=storage_device)

            if cur_ov_h > 0:
                r = get_ramp(cur_ov_h)
                if y_idx > 0:
                    w_h[:cur_ov_h] = r
                if y_end < h:
                    w_h[-cur_ov_h:] = 1.0 - r

            if cur_ov_w > 0:
                r = get_ramp(cur_ov_w)
                if x_idx > 0:
                    w_w[:cur_ov_w] = r
                if x_end < w:
                    w_w[-cur_ov_w:] = 1.0 - r

            final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)

            valid_d = min(tile_out.shape[2], result.shape[2])
            tile_out = tile_out[:, :, :valid_d, :, :]

            tile_out.mul_(final_weight)

            result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
            count[:, :, :, ys:ye, xs:xe] += final_weight

            del tile_out, final_weight, tile_x, w_h, w_w
            bar.update(1)

    result.div_(count.clamp(min=1e-6))

    if result.device != x.device:
        result = result.to(x.device).to(x.dtype)

    if x.shape[2] == 1 and sf_t == 1:
        result = result.squeeze(2)

    return result

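tiled_vae blends overlapping tiles with the half-cosine ramp cached by get_ramp, so the fade-out of one tile and the fade-in of its neighbour always sum to one across a seam. A standalone check of that ramp (toy overlap width, not calling the node code):

```python
# Seam-blending weights: a half-cosine rise from 0 to 1 over the overlap.
import torch

overlap = 4
t = torch.linspace(0, 1, steps=overlap)
ramp = 0.5 - 0.5 * torch.cos(t * torch.pi)   # incoming tile fades in: 0 .. 1
print(ramp)
print(ramp + (1.0 - ramp))                   # all ones: the two tiles sum to 1 at the seam
```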
def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False):
    t = videos.size(temporal_dim)

    if count == 0 and not prepend:
        if t % 4 == 1:
            return videos
        count = ((t - 1) // 4 + 1) * 4 + 1 - t

    if count <= 0:
        return videos

    def select(start, end):
        return videos[start:end] if temporal_dim == 0 else videos[:, start:end]

    if count >= t:
        repeat_count = count - t + 1
        last = select(-1, None)

        if temporal_dim == 0:
            repeated = last.repeat(repeat_count, 1, 1, 1)
            reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:0]
        else:
            repeated = last.expand(-1, repeat_count, -1, -1).contiguous()
            reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:, :0]

        return torch.cat([repeated, reversed_frames, videos] if prepend else
                         [videos, reversed_frames, repeated], dim=temporal_dim)

    if prepend:
        reversed_frames = select(1, count+1).flip(temporal_dim)
    else:
        reversed_frames = select(-count-1, -1).flip(temporal_dim)

    return torch.cat([reversed_frames, videos] if prepend else
                     [videos, reversed_frames], dim=temporal_dim)

def clear_vae_memory(vae_model):
    for module in vae_model.modules():
        if hasattr(module, "memory"):
            module.memory = None
    gc.collect()
    torch.cuda.empty_cache()

def expand_dims(tensor, ndim):
    shape = tensor.shape + (1,) * (ndim - tensor.ndim)
    return tensor.reshape(shape)

def get_conditions(latent, latent_blur):
    t, h, w, c = latent.shape
    cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
    cond[:, ..., :-1] = latent_blur[:]
    cond[:, ..., -1:] = 1.0
    return cond

def timestep_transform(timesteps, latents_shapes):
    vt = 4
    vs = 8
    frames = (latents_shapes[:, 0] - 1) * vt + 1
    heights = latents_shapes[:, 1] * vs
    widths = latents_shapes[:, 2] * vs

    # Compute shift factor.
    def get_lin_function(x1, y1, x2, y2):
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return lambda x: m * x + b

    img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
    vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
    shift = torch.where(
        frames > 1,
        vid_shift_fn(heights * widths * frames),
        img_shift_fn(heights * widths),
    ).to(timesteps.device)

    # Shift timesteps.
    T = 1000.0
    timesteps = timesteps / T
    timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
    timesteps = timesteps * T
    return timesteps

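timestep_transform applies a resolution-dependent shift: the shift factor grows linearly with pixel count, so larger clips have their timesteps pushed toward the noisy end of the schedule. A worked example with the constants from the function above (the 720p, 145-frame clip size is just an illustration):

```python
# Resolution-dependent timestep shift, reduced to plain Python.
def lin(x1, y1, x2, y2):
    m = (y2 - y1) / (x2 - x1)
    return lambda x: m * x + y1 - m * x1

vid_shift = lin(256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0)
shift = vid_shift(1280 * 720 * 145)   # 5.0 for a 720p, 145-frame clip
t = 500.0 / 1000.0
print(1000.0 * shift * t / (1 + (shift - 1) * t))  # ~833.3 instead of 500
```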
def inter(x_0, x_T, t):
    t = expand_dims(t, x_0.ndim)
    T = 1000.0
    B = lambda t: t / T
    A = lambda t: 1 - (t / T)
    return A(t) * x_0 + B(t) * x_T
def area_resize(image, max_area):

    height, width = image.shape[-2:]
    scale = math.sqrt(max_area / (height * width))

    resized_height, resized_width = round(height * scale), round(width * scale)

    return TVF.resize(
        image,
        size=(resized_height, resized_width),
        interpolation=InterpolationMode.BICUBIC,
    )

def div_pad(image, factor):

    height_factor, width_factor = factor
    height, width = image.shape[-2:]

    pad_height = (height_factor - (height % height_factor)) % height_factor
    pad_width = (width_factor - (width % width_factor)) % width_factor

    if pad_height == 0 and pad_width == 0:
        return image

    if isinstance(image, torch.Tensor):
        padding = (0, pad_width, 0, pad_height)
        image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0)

    return image

def cut_videos(videos):
    t = videos.size(1)
    if t == 1:
        return videos
    if t <= 4 :
        padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1)
        padding = torch.cat(padding, dim=1)
        videos = torch.cat([videos, padding], dim=1)
        return videos
    if (t - 1) % (4) == 0:
        return videos
    else:
        padding = [videos[:, -1].unsqueeze(1)] * (
            4 - ((t - 1) % (4))
        )
        padding = torch.cat(padding, dim=1)
        videos = torch.cat([videos, padding], dim=1)
        assert (videos.size(1) - 1) % (4) == 0
        return videos

def side_resize(image, size):
    antialias = not (isinstance(image, torch.Tensor) and image.device.type == 'mps')
    resized = TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=antialias)
    return resized

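cut_videos pads a clip by repeating its last frame until the frame count satisfies (t - 1) % 4 == 0, which is what the 4x temporal VAE factor expects. A small standalone check of that rule (re-deriving the padded lengths rather than calling the function):

```python
# Frame-count rule enforced by cut_videos, expressed as plain arithmetic.
def padded_len(t):
    if t == 1 or (t - 1) % 4 == 0:
        return t
    if t <= 4:
        return t + (4 - t + 1)
    return t + (4 - (t - 1) % 4)

print([padded_len(t) for t in (1, 2, 3, 6, 10)])  # [1, 5, 5, 9, 13]
```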
class SeedVR2InputProcessing(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id = "SeedVR2InputProcessing",
            category="image/video",
            inputs = [
                io.Image.Input("images"),
                io.Vae.Input("vae"),
                io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
                io.Int.Input("spatial_tile_size", default = 512, min = 1),
                io.Int.Input("spatial_overlap", default = 64, min = 1),
                io.Int.Input("temporal_tile_size", default=5, min=1, max=16384, step=4),
                io.Boolean.Input("enable_tiling", default=False),
            ],
            outputs = [
                io.Latent.Output("vae_conditioning")
            ]
        )

    @classmethod
    def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):

        comfy.model_management.load_models_gpu([vae.patcher])
        vae_model = vae.first_stage_model
        scale = 0.9152
        shift = 0
        if images.dim() != 5: # add the t dim
            images = images.unsqueeze(0)
        images = images.permute(0, 1, 4, 2, 3)

        b, t, c, h, w = images.shape
        images = images.reshape(b * t, c, h, w)

        clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
        normalize = Normalize(0.5, 0.5)
        images = side_resize(images, resolution)

        images = clip(images)
        o_h, o_w = images.shape[-2:]
        images = div_pad(images, (16, 16))
        images = normalize(images)
        _, _, new_h, new_w = images.shape

        images = images.reshape(b, t, c, new_h, new_w)
        images = cut_videos(images)

        images = rearrange(images, "b t c h w -> b c t h w")

        # in case the user picks a non-compatible number for tiling
        def make_divisible(val, divisor):
            return max(divisor, round(val / divisor) * divisor)

        spatial_tile_size = make_divisible(spatial_tile_size, 32)
        spatial_overlap = make_divisible(spatial_overlap, 32)

        if spatial_overlap >= spatial_tile_size:
            spatial_overlap = max(0, spatial_tile_size - 8)

        args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
                "temporal_size":temporal_tile_size}
        if enable_tiling:
            latent = tiled_vae(images, vae_model, encode=True, **args)
        else:
            latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0]

        clear_vae_memory(vae_model)
        #images = images.to(offload_device)
        #vae_model = vae_model.to(offload_device)

        vae_model.img_dims = [o_h, o_w]
        args["enable_tiling"] = enable_tiling
        vae_model.tiled_args = args
        vae_model.original_image_video = images

        latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
        latent = rearrange(latent, "b c ... -> b ... c")

        latent = (latent - shift) * scale

        return io.NodeOutput({"samples": latent})

class SeedVR2Conditioning(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SeedVR2Conditioning",
            category="image/video",
            inputs=[
                io.Latent.Input("vae_conditioning"),
                io.Model.Input("model"),
                io.Float.Input("latent_noise_scale", default=0.0, step=0.001)
            ],
            outputs=[io.Conditioning.Output(display_name = "positive"),
                     io.Conditioning.Output(display_name = "negative"),
                     io.Latent.Output(display_name = "latent")],
        )

    @classmethod
    def execute(cls, vae_conditioning, model, latent_noise_scale) -> io.NodeOutput:

        vae_conditioning = vae_conditioning["samples"]
        device = vae_conditioning.device
        model = model.model.diffusion_model
        pos_cond = model.positive_conditioning
        neg_cond = model.negative_conditioning

        noises = torch.randn_like(vae_conditioning).to(device)
        aug_noises = torch.randn_like(vae_conditioning).to(device)
        aug_noises = noises * 0.1 + aug_noises * 0.05
        cond_noise_scale = latent_noise_scale
        t = (
            torch.tensor([1000.0])
            * cond_noise_scale
        ).to(device)
        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None] # avoid batch dim
        t = timestep_transform(t, shape)
        cond = inter(vae_conditioning, aug_noises, t)
        condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
        condition = condition.movedim(-1, 1)
        noises = noises.movedim(-1, 1)

        pos_shape = pos_cond.shape[0]
        neg_shape = neg_cond.shape[0]
        diff = abs(pos_shape - neg_shape)
        if pos_shape > neg_shape:
            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
        else:
            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

        noises = rearrange(noises, "b c t h w -> b (c t) h w")
        condition = rearrange(condition, "b c t h w -> b (c t) h w")

        negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
        positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]

        return io.NodeOutput(positive, negative, {"samples": noises})

class SeedVRExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            SeedVR2Conditioning,
            SeedVR2InputProcessing
        ]

async def comfy_entrypoint() -> SeedVRExtension:
    return SeedVRExtension()

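SeedVR2Conditioning pads whichever of the positive or negative text embeddings is shorter so the two have the same token count before they are wrapped into conditioning. A standalone check of that F.pad call (the 77/32 token counts and 4096 width below are toy values):

```python
# Zero-pad the shorter embedding along the token dimension, not the feature dimension.
import torch
import torch.nn.functional as F

pos = torch.randn(77, 4096)
neg = torch.randn(32, 4096)
diff = abs(pos.shape[0] - neg.shape[0])
neg = F.pad(neg, (0, 0, 0, diff))   # (last dim: 0,0) then (token dim: 0,diff)
print(pos.shape, neg.shape)         # both torch.Size([77, 4096])
```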