mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-05 08:16:47 +08:00
Merge 553f71aa9e into fc1fdf3389
This commit is contained in:
commit
ce0c979e96
@ -755,6 +755,10 @@ class ACEAudio(LatentFormat):
|
|||||||
latent_channels = 8
|
latent_channels = 8
|
||||||
latent_dimensions = 2
|
latent_dimensions = 2
|
||||||
|
|
||||||
|
class SeedVR2(LatentFormat):
    """Latent format descriptor for SeedVR2 models."""

    # Channel count of a SeedVR2 latent.
    latent_channels = 16
    # NOTE(review): the neighbouring audio formats declare 1 or 2 here;
    # confirm that 16 is the intended value for this attribute.
    latent_dimensions = 16
|
||||||
|
|
||||||
class ACEAudio15(LatentFormat):
|
class ACEAudio15(LatentFormat):
|
||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
latent_dimensions = 1
|
latent_dimensions = 1
|
||||||
|
|||||||
@ -719,7 +719,30 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
    """Attention over packed variable-length sequences via jagged nested tensors.

    Args:
        q, k, v: packed tensors. Unless ``skip_reshape`` is set they must be
            2D ``[total_tokens, embed_dim]`` and are split into ``heads``.
        heads: number of attention heads.
        cu_seqlens_q: offsets handed to ``nested_tensor_from_jagged`` for ``q``.
        cu_seqlens_k: offsets handed to ``nested_tensor_from_jagged`` for ``k``
            and ``v``.
        skip_reshape: when true, the inputs are used without the head split.
        skip_output_reshape: when true, return the packed values without
            merging the last two axes.

    Returns:
        The packed attention output (``out.values()``), reshaped to
        ``[-1, heads * head_dim]`` unless ``skip_output_reshape`` is set.
    """
    if not skip_reshape:
        # Unpacking enforces the 2D [total_tokens, embed_dim] layout.
        _token_count, width = q.shape
        per_head = width // heads
        q, k, v = (packed.view(packed.shape[0], heads, per_head) for packed in (q, k, v))

    to_jagged = torch.nested.nested_tensor_from_jagged
    q = to_jagged(q, offsets=cu_seqlens_q.long()).transpose(1, 2)
    k = to_jagged(k, offsets=cu_seqlens_k.long()).transpose(1, 2)
    v = to_jagged(v, offsets=cu_seqlens_k.long()).transpose(1, 2)

    attended = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    packed_out = attended.transpose(1, 2).values()

    if skip_output_reshape:
        return packed_out
    # q's last axis is the per-head width after the transpose above.
    return packed_out.reshape(-1, heads * (q.shape[-1]))
|
||||||
|
|
||||||
|
optimized_var_attention = var_attention_pytorch
|
||||||
optimized_attention = attention_basic
|
optimized_attention = attention_basic
|
||||||
|
|
||||||
if model_management.sage_attention_enabled():
|
if model_management.sage_attention_enabled():
|
||||||
|
|||||||
@ -13,6 +13,7 @@ if model_management.xformers_enabled_vae():
|
|||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
|
|
||||||
|
|
||||||
def torch_cat_if_needed(xl, dim):
|
def torch_cat_if_needed(xl, dim):
|
||||||
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
|
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
|
||||||
if len(xl) > 1:
|
if len(xl) > 1:
|
||||||
@ -22,7 +23,7 @@ def torch_cat_if_needed(xl, dim):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_timestep_embedding(timesteps, embedding_dim):
|
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos = False, downscale_freq_shift = 1):
|
||||||
"""
|
"""
|
||||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||||
From Fairseq.
|
From Fairseq.
|
||||||
@ -33,11 +34,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
|||||||
assert len(timesteps.shape) == 1
|
assert len(timesteps.shape) == 1
|
||||||
|
|
||||||
half_dim = embedding_dim // 2
|
half_dim = embedding_dim // 2
|
||||||
emb = math.log(10000) / (half_dim - 1)
|
emb = math.log(10000) / (half_dim - downscale_freq_shift)
|
||||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||||
emb = emb.to(device=timesteps.device)
|
emb = emb.to(device=timesteps.device)
|
||||||
emb = timesteps.float()[:, None] * emb[None, :]
|
emb = timesteps.float()[:, None] * emb[None, :]
|
||||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||||
|
if flip_sin_to_cos:
|
||||||
|
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||||
if embedding_dim % 2 == 1: # zero pad
|
if embedding_dim % 2 == 1: # zero pad
|
||||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||||
return emb
|
return emb
|
||||||
|
|||||||
1480
comfy/ldm/seedvr/model.py
Normal file
1480
comfy/ldm/seedvr/model.py
Normal file
File diff suppressed because it is too large
Load Diff
2145
comfy/ldm/seedvr/vae.py
Normal file
2145
comfy/ldm/seedvr/vae.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -48,6 +48,8 @@ import comfy.ldm.chroma.model
|
|||||||
import comfy.ldm.chroma_radiance.model
|
import comfy.ldm.chroma_radiance.model
|
||||||
import comfy.ldm.ace.model
|
import comfy.ldm.ace.model
|
||||||
import comfy.ldm.omnigen.omnigen2
|
import comfy.ldm.omnigen.omnigen2
|
||||||
|
import comfy.ldm.seedvr.model
|
||||||
|
|
||||||
import comfy.ldm.qwen_image.model
|
import comfy.ldm.qwen_image.model
|
||||||
import comfy.ldm.kandinsky5.model
|
import comfy.ldm.kandinsky5.model
|
||||||
import comfy.ldm.anima.model
|
import comfy.ldm.anima.model
|
||||||
@ -828,6 +830,16 @@ class HunyuanDiT(BaseModel):
|
|||||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class SeedVR2(BaseModel):
    """Model wrapper that instantiates the SeedVR2 ``NaDiT`` diffusion network."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        """Build the base model around ``comfy.ldm.seedvr.model.NaDiT``."""
        super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)

    def extra_conds(self, **kwargs):
        """Extend the base conds with ``"condition"`` when the caller supplies one."""
        conds = super().extra_conds(**kwargs)
        extra = kwargs.get("condition", None)
        if extra is None:
            return conds
        conds["condition"] = comfy.conds.CONDRegular(extra)
        return conds
|
||||||
|
|
||||||
class PixArt(BaseModel):
|
class PixArt(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
|
||||||
|
|||||||
@ -490,6 +490,28 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "seedvr2"
|
||||||
|
dit_config["vid_dim"] = 3072
|
||||||
|
dit_config["heads"] = 24
|
||||||
|
dit_config["num_layers"] = 36
|
||||||
|
dit_config["norm_eps"] = 1e-5
|
||||||
|
dit_config["qk_rope"] = True
|
||||||
|
dit_config["mlp_type"] = "normal"
|
||||||
|
return dit_config
|
||||||
|
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "seedvr2"
|
||||||
|
dit_config["vid_dim"] = 2560
|
||||||
|
dit_config["heads"] = 20
|
||||||
|
dit_config["num_layers"] = 32
|
||||||
|
dit_config["norm_eps"] = 1.0e-05
|
||||||
|
dit_config["qk_rope"] = None
|
||||||
|
dit_config["mlp_type"] = "swiglu"
|
||||||
|
dit_config["vid_out_norm"] = True
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "wan2.1"
|
dit_config["image_model"] = "wan2.1"
|
||||||
|
|||||||
0
comfy/samplers.py
Executable file → Normal file
0
comfy/samplers.py
Executable file → Normal file
16
comfy/sd.py
16
comfy/sd.py
@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
|
|||||||
import comfy.ldm.wan.vae
|
import comfy.ldm.wan.vae
|
||||||
import comfy.ldm.wan.vae2_2
|
import comfy.ldm.wan.vae2_2
|
||||||
import comfy.ldm.hunyuan3d.vae
|
import comfy.ldm.hunyuan3d.vae
|
||||||
|
import comfy.ldm.seedvr.vae
|
||||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||||
import comfy.ldm.hunyuan_video.vae
|
import comfy.ldm.hunyuan_video.vae
|
||||||
import comfy.ldm.mmaudio.vae.autoencoder
|
import comfy.ldm.mmaudio.vae.autoencoder
|
||||||
@ -437,7 +438,8 @@ class CLIP:
|
|||||||
class VAE:
|
class VAE:
|
||||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
if metadata is None or metadata.get("keep_diffusers_format") != "true":
|
||||||
|
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||||
|
|
||||||
if model_management.is_amd():
|
if model_management.is_amd():
|
||||||
VAE_KL_MEM_RATIO = 2.73
|
VAE_KL_MEM_RATIO = 2.73
|
||||||
@ -506,6 +508,17 @@ class VAE:
|
|||||||
self.first_stage_model = StageC_coder()
|
self.first_stage_model = StageC_coder()
|
||||||
self.downscale_ratio = 32
|
self.downscale_ratio = 32
|
||||||
self.latent_channels = 16
|
self.latent_channels = 16
|
||||||
|
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
|
||||||
|
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||||
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||||
|
self.downscale_index_formula = (4, 8, 8)
|
||||||
|
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||||
|
self.upscale_index_formula = (4, 8, 8)
|
||||||
|
self.process_input = lambda image: image
|
||||||
|
self.crop_input = False
|
||||||
elif "decoder.conv_in.weight" in sd:
|
elif "decoder.conv_in.weight" in sd:
|
||||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||||
@ -626,6 +639,7 @@ class VAE:
|
|||||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||||
self.downscale_index_formula = (8, 32, 32)
|
self.downscale_index_formula = (8, 32, 32)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||||
|
|||||||
@ -1415,6 +1415,25 @@ class Chroma(supported_models_base.BASE):
|
|||||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||||
|
|
||||||
|
class SeedVR2(supported_models_base.BASE):
    """Supported-model entry matched by ``image_model == "seedvr2"``."""

    unet_config = {
        "image_model": "seedvr2",
    }
    latent_format = comfy.latent_formats.SeedVR2

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    sampling_settings = {
        "shift": 1.0,
    }

    def get_model(self, state_dict, prefix = "", device=None):
        """Create the ``model_base.SeedVR2`` wrapper for this configuration."""
        return model_base.SeedVR2(self, device=device)

    def clip_target(self, state_dict={}):
        """Return ``None``: this entry declares no text-encoder target."""
        return None
|
||||||
|
|
||||||
class ChromaRadiance(Chroma):
|
class ChromaRadiance(Chroma):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "chroma_radiance",
|
"image_model": "chroma_radiance",
|
||||||
@ -1734,6 +1753,6 @@ class LongCatImage(supported_models_base.BASE):
|
|||||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, SeedVR2]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
470
comfy_extras/nodes_seedvr.py
Normal file
470
comfy_extras/nodes_seedvr.py
Normal file
@ -0,0 +1,470 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import comfy.model_management
|
||||||
|
from comfy.utils import ProgressBar
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torchvision.transforms import functional as TVF
|
||||||
|
from torchvision.transforms import Lambda, Normalize
|
||||||
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
|
@torch.inference_mode()
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True):
    """Encode or decode ``x`` in overlapping spatial tiles and blend the results.

    The input is cast to the dtype of the VAE's first parameter; an input that
    is not 5D gets a singleton axis inserted at dim 2. Each spatial tile is fed
    to the VAE in temporal chunks, weighted with raised-cosine ramps on edges
    that border another tile, accumulated in float32 and finally divided by the
    accumulated weights.

    Args:
        x: input tensor, indexed as ``[b, c, d, h, w]`` after the optional unsqueeze.
        vae_model: object providing ``parameters()``, ``device``, ``encode`` and
            ``decode_``; ``spatial_downsample_factor`` / ``temporal_downsample_factor``
            are read when present (defaults 8 and 4).
        tile_size: ``(height, width)`` of a tile; divided by the spatial factor
            when decoding.
        tile_overlap: ``(height, width)`` overlap between tiles; divided by the
            spatial factor when decoding.
        temporal_size: chunk length along dim 2; divided by the temporal factor
            when decoding.
        encode: ``True`` to call ``vae_model.encode``, ``False`` for ``decode_``.

    Returns:
        The blended tensor. It is moved to ``x``'s device and dtype only when
        the accumulation device differs from ``x.device``.
    """
    gc.collect()
    torch.cuda.empty_cache()

    x = x.to(next(vae_model.parameters()).dtype)
    if x.ndim != 5:
        x = x.unsqueeze(2)

    _, _, in_d, in_h, in_w = x.shape

    spatial_factor = getattr(vae_model, "spatial_downsample_factor", 8)
    temporal_factor = getattr(vae_model, "temporal_downsample_factor", 4)

    if encode:
        tile_h, tile_w = tile_size
        overlap_h, overlap_w = tile_overlap
        out_d = (in_d + temporal_factor - 1) // temporal_factor
        out_h = (in_h + spatial_factor - 1) // spatial_factor
        out_w = (in_w + spatial_factor - 1) // spatial_factor
    else:
        # Decoding walks the latent grid, so tile geometry is scaled down.
        tile_h = max(1, tile_size[0] // spatial_factor)
        tile_w = max(1, tile_size[1] // spatial_factor)
        overlap_h = max(0, tile_overlap[0] // spatial_factor)
        overlap_w = max(0, tile_overlap[1] // spatial_factor)

        out_d = in_d * temporal_factor
        out_h = in_h * spatial_factor
        out_w = in_w * spatial_factor

    step_h = max(1, tile_h - overlap_h)
    step_w = max(1, tile_w - overlap_w)

    storage_device = vae_model.device
    blended = None
    weight_sum = None

    def run_temporal_chunks(spatial_tile):
        """Run the VAE over one spatial tile, chunk by chunk along dim 2."""
        frames_in = spatial_tile.shape[2]
        chunk_len = temporal_size if encode else max(1, temporal_size // temporal_factor)

        pieces = []
        for start in range(0, frames_in, chunk_len):
            chunk = spatial_tile[:, :, start : start + chunk_len, :, :]
            valid_len = chunk.shape[2]

            # A short final chunk is padded by repeating its last frame.
            missing = max(0, chunk_len - valid_len)
            if missing > 0:
                filler = chunk[:, :, -1:, :, :].repeat(1, 1, missing, 1, 1)
                chunk = torch.cat([chunk, filler], dim=2)
            chunk = chunk.contiguous()

            out = vae_model.encode(chunk)[0] if encode else vae_model.decode_(chunk)

            if isinstance(out, (tuple, list)):
                out = out[0]
            if out.ndim == 4:
                out = out.unsqueeze(2)

            if missing > 0:
                # Drop the outputs that correspond to the padded frames.
                if encode:
                    keep = (valid_len + temporal_factor - 1) // temporal_factor
                else:
                    keep = valid_len * temporal_factor
                out = out[:, :, :keep, :, :]

            pieces.append(out.to(storage_device))

        return torch.cat(pieces, dim=2)

    ramp_cache = {}

    def get_ramp(steps):
        """Return (and cache) a 0 -> 1 raised-cosine ramp with ``steps`` samples."""
        if steps not in ramp_cache:
            t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32)
            ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
        return ramp_cache[steps]

    def edge_profile(length, overlap, fade_in, fade_out):
        """Build the 1D blend weights for one axis of a tile."""
        profile = torch.ones((length,), device=storage_device)
        if overlap > 0:
            ramp = get_ramp(overlap)
            if fade_in:
                profile[:overlap] = ramp
            if fade_out:
                profile[-overlap:] = 1.0 - ramp
        return profile

    total_tiles = len(range(0, in_h, step_h)) * len(range(0, in_w, step_w))
    bar = ProgressBar(total_tiles)

    for y_idx in range(0, in_h, step_h):
        y_end = min(y_idx + tile_h, in_h)

        for x_idx in range(0, in_w, step_w):
            x_end = min(x_idx + tile_w, in_w)

            tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]

            # Run VAE
            tile_out = run_temporal_chunks(tile_x)

            if blended is None:
                # Allocate accumulators once the output batch/channel sizes are known.
                b_out, c_out = tile_out.shape[0], tile_out.shape[1]
                blended = torch.zeros((b_out, c_out, out_d, out_h, out_w), device=storage_device, dtype=torch.float32)
                weight_sum = torch.zeros((1, 1, 1, out_h, out_w), device=storage_device, dtype=torch.float32)

            out_tile_h, out_tile_w = tile_out.shape[3], tile_out.shape[4]
            if encode:
                ys = y_idx // spatial_factor
                xs = x_idx // spatial_factor
                cur_ov_h = max(0, min(overlap_h // spatial_factor, out_tile_h // 2))
                cur_ov_w = max(0, min(overlap_w // spatial_factor, out_tile_w // 2))
            else:
                ys = y_idx * spatial_factor
                xs = x_idx * spatial_factor
                cur_ov_h = max(0, min(overlap_h, out_tile_h // 2))
                cur_ov_w = max(0, min(overlap_w, out_tile_w // 2))
            ye = ys + out_tile_h
            xe = xs + out_tile_w

            # Fade only on the sides that have a neighbouring tile.
            w_h = edge_profile(out_tile_h, cur_ov_h, y_idx > 0, y_end < in_h)
            w_w = edge_profile(out_tile_w, cur_ov_w, x_idx > 0, x_end < in_w)

            final_weight = w_h.view(1, 1, 1, -1, 1) * w_w.view(1, 1, 1, 1, -1)

            valid_d = min(tile_out.shape[2], blended.shape[2])
            tile_out = tile_out[:, :, :valid_d, :, :]

            tile_out.mul_(final_weight)

            blended[:, :, :valid_d, ys:ye, xs:xe] += tile_out
            weight_sum[:, :, :, ys:ye, xs:xe] += final_weight

            del tile_out, final_weight, tile_x, w_h, w_w
            bar.update(1)

    blended.div_(weight_sum.clamp(min=1e-6))

    if blended.device != x.device:
        blended = blended.to(x.device).to(x.dtype)

    if x.shape[2] == 1 and temporal_factor == 1:
        blended = blended.squeeze(2)

    return blended
|
||||||
|
|
||||||
|
def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False):
    """Pad ``videos`` along the temporal axis with mirrored frames.

    When ``count`` is 0 and ``prepend`` is false, the pad length is chosen so
    the frame count becomes ``4k + 1`` (inputs already of that form are returned
    unchanged). The padding is a reflection of the frames next to the padded
    edge; if more frames are requested than a reflection can supply, the last
    frame is repeated to make up the difference.

    Args:
        videos: tensor to pad; frames are sliced on dim 0 when
            ``temporal_dim == 0`` and on dim 1 otherwise.
        count: number of frames to add (0 selects the automatic amount).
        temporal_dim: axis passed to ``size``, ``flip`` and ``cat``.
        prepend: add the padding before the video instead of after it.

    Returns:
        The padded tensor, or ``videos`` itself when nothing has to be added.
    """
    num_frames = videos.size(temporal_dim)

    if count == 0 and not prepend:
        if num_frames % 4 == 1:
            return videos
        count = ((num_frames - 1) // 4 + 1) * 4 + 1 - num_frames

    if count <= 0:
        return videos

    def frames(start, stop):
        if temporal_dim == 0:
            return videos[start:stop]
        return videos[:, start:stop]

    if count < num_frames:
        # Enough frames exist for a pure reflection.
        if prepend:
            pieces = [frames(1, count + 1).flip(temporal_dim), videos]
        else:
            pieces = [videos, frames(-count - 1, -1).flip(temporal_dim)]
        return torch.cat(pieces, dim=temporal_dim)

    # The reflection is too short: top it up with copies of the last frame.
    extra = count - num_frames + 1
    tail = frames(-1, None)

    if temporal_dim == 0:
        held = tail.repeat(extra, 1, 1, 1)
        mirrored = frames(1, None).flip(temporal_dim) if num_frames > 1 else tail[:0]
    else:
        held = tail.expand(-1, extra, -1, -1).contiguous()
        mirrored = frames(1, None).flip(temporal_dim) if num_frames > 1 else tail[:, :0]

    pieces = [held, mirrored, videos] if prepend else [videos, mirrored, held]
    return torch.cat(pieces, dim=temporal_dim)
|
||||||
|
|
||||||
|
def clear_vae_memory(vae_model):
    """Reset every submodule's ``memory`` attribute and release cached memory.

    Each module yielded by ``vae_model.modules()`` that has a ``memory``
    attribute gets it set to ``None``; afterwards the garbage collector is run
    and the CUDA caching allocator is asked to free unused blocks.
    """
    for submodule in vae_model.modules():
        if not hasattr(submodule, "memory"):
            continue
        submodule.memory = None

    gc.collect()
    torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def expand_dims(tensor, ndim):
    """Append trailing singleton axes until ``tensor`` has ``ndim`` dimensions.

    A tensor that already has ``ndim`` or more dimensions keeps its shape.
    """
    missing = ndim - tensor.ndim
    target_shape = tensor.shape + (1,) * missing
    return tensor.reshape(target_shape)
|
||||||
|
|
||||||
|
def get_conditions(latent, latent_blur):
    """Build a conditioning tensor: ``latent_blur`` plus a trailing channel of ones.

    Args:
        latent: 4D tensor; only its shape, device and dtype are used.
        latent_blur: values written into all but the last channel of the result.

    Returns:
        A tensor with ``latent``'s first three sizes and one extra trailing
        channel, which is filled with 1.0.
    """
    frames, height, width, channels = latent.shape
    cond = torch.ones((frames, height, width, channels + 1), device=latent.device, dtype=latent.dtype)
    # The extra last channel keeps its initial value of one.
    cond[..., :channels] = latent_blur[:]
    return cond
|
||||||
|
|
||||||
|
def timestep_transform(timesteps, latents_shapes):
    """Warp timesteps with a resolution-dependent shift factor.

    Args:
        timesteps: tensor of timesteps on a 0..1000 scale.
        latents_shapes: 2D integer tensor whose first three columns are read
            as latent frame count, height and width.

    Returns:
        ``timesteps`` remapped by ``shift * t / (1 + (shift - 1) * t)`` applied
        on the 0..1 scale, then scaled back by 1000.
    """
    temporal_factor = 4
    spatial_factor = 8
    frames = (latents_shapes[:, 0] - 1) * temporal_factor + 1
    heights = latents_shapes[:, 1] * spatial_factor
    widths = latents_shapes[:, 2] * spatial_factor

    def line_through(x1, y1, x2, y2):
        """Return the linear function passing through (x1, y1) and (x2, y2)."""
        slope = (y2 - y1) / (x2 - x1)
        intercept = y1 - slope * x1

        def evaluate(x):
            return slope * x + intercept

        return evaluate

    # Compute shift factor: one line for single frames, another for videos.
    image_shift = line_through(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
    video_shift = line_through(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
    shift = torch.where(
        frames > 1,
        video_shift(heights * widths * frames),
        image_shift(heights * widths),
    ).to(timesteps.device)

    # Shift timesteps on the unit interval, then restore the original scale.
    full_scale = 1000.0
    unit = timesteps / full_scale
    warped = shift * unit / (1 + (shift - 1) * unit)
    return warped * full_scale
|
||||||
|
|
||||||
|
def inter(x_0, x_T, t):
    """Linearly interpolate between ``x_0`` and ``x_T`` at timestep ``t``.

    ``t`` is on a 0..1000 scale and is given trailing singleton axes so it
    broadcasts against ``x_0``. The result is ``(1 - t/1000) * x_0 + (t/1000) * x_T``.
    """
    t = expand_dims(t, x_0.ndim)
    full_scale = 1000.0
    weight_T = t / full_scale
    weight_0 = 1 - (t / full_scale)
    return weight_0 * x_0 + weight_T * x_T
|
||||||
|
def area_resize(image, max_area):
    """Bicubically resize ``image`` so its area is about ``max_area``.

    Both of the last two dimensions are multiplied by the same factor,
    ``sqrt(max_area / (height * width))``, and rounded to integers.
    """
    height, width = image.shape[-2:]
    scale = math.sqrt(max_area / (height * width))
    target_size = (round(height * scale), round(width * scale))

    return TVF.resize(
        image,
        size=target_size,
        interpolation=InterpolationMode.BICUBIC,
    )
|
||||||
|
|
||||||
|
def div_pad(image, factor):
    """Zero-pad the last two dimensions of ``image`` up to multiples of ``factor``.

    Args:
        image: input whose last two dimensions are height and width.
        factor: ``(height_factor, width_factor)`` pair.

    Returns:
        ``image`` itself when no padding is needed or when it is not a
        ``torch.Tensor``; otherwise a copy padded with zeros on the bottom and
        right edges.
    """
    height_factor, width_factor = factor
    height, width = image.shape[-2:]

    # Distance from each size up to the next multiple of its factor.
    pad_bottom = -height % height_factor
    pad_right = -width % width_factor

    needs_padding = pad_bottom != 0 or pad_right != 0
    if needs_padding and isinstance(image, torch.Tensor):
        return torch.nn.functional.pad(image, (0, pad_right, 0, pad_bottom), mode='constant', value=0.0)
    return image
|
||||||
|
|
||||||
|
def cut_videos(videos):
    """Extend dim 1 of ``videos`` to a length of the form ``4k + 1``.

    The last frame is repeated as many times as needed. Inputs whose length
    already has that form (including a single frame) are returned unchanged.
    """
    num_frames = videos.size(1)

    # Frames missing to reach the next length congruent to 1 modulo 4.
    missing = (1 - num_frames) % 4
    if missing == 0:
        return videos

    last_frame = videos[:, -1].unsqueeze(1)
    return torch.cat([videos] + [last_frame] * missing, dim=1)
|
||||||
|
|
||||||
|
def side_resize(image, size):
    """Bicubically resize ``image`` to ``size`` with torchvision's ``resize``.

    Antialiasing is enabled except for tensors located on an ``mps`` device.
    """
    on_mps = isinstance(image, torch.Tensor) and image.device.type == 'mps'
    return TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=not on_mps)
|
||||||
|
|
||||||
|
class SeedVR2InputProcessing(io.ComfyNode):
    """Node that resizes, normalizes and VAE-encodes images for SeedVR2."""

    @classmethod
    def define_schema(cls):
        """Describe the node's inputs and its single latent output."""
        return io.Schema(
            node_id="SeedVR2InputProcessing",
            category="image/video",
            inputs=[
                io.Image.Input("images"),
                io.Vae.Input("vae"),
                io.Int.Input("resolution", default=1280, min=120),  # just non-zero value
                io.Int.Input("spatial_tile_size", default=512, min=1),
                io.Int.Input("spatial_overlap", default=64, min=1),
                io.Int.Input("temporal_tile_size", default=5, min=1, max=16384, step=4),
                io.Boolean.Input("enable_tiling", default=False),
            ],
            outputs=[
                io.Latent.Output("vae_conditioning"),
            ],
        )

    @classmethod
    def execute(cls, images, vae, resolution, spatial_tile_size, spatial_overlap, temporal_tile_size, enable_tiling):
        """Preprocess ``images`` and encode them with the VAE.

        The frames are resized, clamped to [0, 1], zero-padded to multiples
        of 16, normalized with mean 0.5 / std 0.5 and padded in time by
        ``cut_videos`` before being encoded, either in one pass or tiled.
        The image size before padding, the tiling arguments and the
        preprocessed frames are stored as attributes on the VAE model.

        Returns:
            ``io.NodeOutput`` wrapping ``{"samples": latent}`` with the channel
            axis moved last and the latent multiplied by 0.9152.
        """
        comfy.model_management.load_models_gpu([vae.patcher])
        vae_model = vae.first_stage_model
        latent_scale = 0.9152
        latent_shift = 0

        if images.dim() != 5:  # add the t dim
            images = images.unsqueeze(0)
        images = images.permute(0, 1, 4, 2, 3)

        batch, frames, channels, height, width = images.shape
        images = images.reshape(batch * frames, channels, height, width)

        images = side_resize(images, resolution)
        images = torch.clamp(images, 0.0, 1.0)

        # Remember the size before padding; it is stored on the VAE below.
        orig_h, orig_w = images.shape[-2:]
        images = div_pad(images, (16, 16))
        images = Normalize(0.5, 0.5)(images)
        _, _, padded_h, padded_w = images.shape

        images = images.reshape(batch, frames, channels, padded_h, padded_w)
        images = cut_videos(images)

        images = rearrange(images, "b t c h w -> b c t h w")

        # In case the user supplies a tiling value that is not a multiple of 32.
        def snap_to_multiple(value, divisor):
            return max(divisor, round(value / divisor) * divisor)

        spatial_tile_size = snap_to_multiple(spatial_tile_size, 32)
        spatial_overlap = snap_to_multiple(spatial_overlap, 32)

        if spatial_overlap >= spatial_tile_size:
            spatial_overlap = max(0, spatial_tile_size - 8)

        tiling_args = {
            "tile_size": (spatial_tile_size, spatial_tile_size),
            "tile_overlap": (spatial_overlap, spatial_overlap),
            "temporal_size": temporal_tile_size,
        }
        if enable_tiling:
            latent = tiled_vae(images, vae_model, encode=True, **tiling_args)
        else:
            latent = vae_model.encode(images, orig_dims=[orig_h, orig_w])[0]

        clear_vae_memory(vae_model)

        # Stash state on the VAE model object for later use.
        vae_model.img_dims = [orig_h, orig_w]
        tiling_args["enable_tiling"] = enable_tiling
        vae_model.tiled_args = tiling_args
        vae_model.original_image_video = images

        if latent.ndim == 4:
            latent = latent.unsqueeze(2)
        latent = rearrange(latent, "b c ... -> b ... c")

        latent = (latent - latent_shift) * latent_scale

        return io.NodeOutput({"samples": latent})
|
||||||
|
|
||||||
|
class SeedVR2Conditioning(io.ComfyNode):
    """Node that builds SeedVR2 conditioning and the initial noise latent."""

    @classmethod
    def define_schema(cls):
        """Describe the node's inputs and its three outputs."""
        return io.Schema(
            node_id="SeedVR2Conditioning",
            category="image/video",
            inputs=[
                io.Latent.Input("vae_conditioning"),
                io.Model.Input("model"),
                io.Float.Input("latent_noise_scale", default=0.0, step=0.001),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, vae_conditioning, model, latent_noise_scale) -> io.NodeOutput:
        """Create positive/negative conditioning and a random starting latent.

        The encoded latent is blended with augmentation noise at a timestep
        derived from ``latent_noise_scale``, extended with a trailing channel
        of ones and attached to both conditionings under ``"condition"``. The
        positive and negative embeddings are read from the diffusion model;
        the shorter one is zero-padded along its first dimension.

        Returns:
            ``io.NodeOutput(positive, negative, {"samples": noises})``.
        """
        samples = vae_conditioning["samples"]
        device = samples.device
        diffusion_model = model.model.diffusion_model
        pos_cond = diffusion_model.positive_conditioning
        neg_cond = diffusion_model.negative_conditioning

        # Cast the rotary embedding frequencies to float32 in place.
        for module in diffusion_model.modules():
            if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'):
                module.rope.freqs.data = module.rope.freqs.data.to(torch.float32)

        noises = torch.randn_like(samples)
        aug_noises = torch.randn_like(samples)
        aug_noises = noises * 0.1 + aug_noises * 0.05

        t = (torch.tensor([1000.0]) * latent_noise_scale).to(device)
        latent_shape = torch.tensor(samples.shape[1:]).to(device)[None]  # avoid batch dim
        t = timestep_transform(t, latent_shape)
        blended = inter(samples, aug_noises, t)

        per_sample = [get_conditions(noise, blur) for noise, blur in zip(noises, blended)]
        condition = torch.stack(per_sample).movedim(-1, 1)
        noises = noises.movedim(-1, 1)

        # Zero-pad the shorter embedding so both have the same first dimension.
        pos_len = pos_cond.shape[0]
        neg_len = neg_cond.shape[0]
        gap = abs(pos_len - neg_len)
        if pos_len > neg_len:
            neg_cond = F.pad(neg_cond, (0, 0, 0, gap))
        else:
            pos_cond = F.pad(pos_cond, (0, 0, 0, gap))

        noises = rearrange(noises, "b c t h w -> b (c t) h w")
        condition = rearrange(condition, "b c t h w -> b (c t) h w")

        negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
        positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]

        return io.NodeOutput(positive, negative, {"samples": noises})
|
||||||
|
|
||||||
|
class SeedVRExtension(ComfyExtension):
    """Extension exposing the SeedVR2 node classes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[io.ComfyNode]] = [
            SeedVR2Conditioning,
            SeedVR2InputProcessing,
        ]
        return node_classes
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> SeedVRExtension:
    """Create and return a ``SeedVRExtension`` instance."""
    extension = SeedVRExtension()
    return extension
|
||||||
1
nodes.py
1
nodes.py
@ -2431,6 +2431,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_camera_trajectory.py",
|
"nodes_camera_trajectory.py",
|
||||||
"nodes_edit_model.py",
|
"nodes_edit_model.py",
|
||||||
"nodes_tcfg.py",
|
"nodes_tcfg.py",
|
||||||
|
"nodes_seedvr.py",
|
||||||
"nodes_context_windows.py",
|
"nodes_context_windows.py",
|
||||||
"nodes_qwen.py",
|
"nodes_qwen.py",
|
||||||
"nodes_chroma_radiance.py",
|
"nodes_chroma_radiance.py",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user