diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 86836468f..42567fa30 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1187,8 +1187,12 @@ class NaDiT(nn.Module): rope_dim = 128, rope_type = "mmrope3d", vid_out_norm: Optional[str] = None, + device = None, + dtype = None, + operations = None, **kwargs, ): + self.dtype = dtype window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] txt_dim = vid_dim emb_dim = vid_dim * 6 @@ -1292,33 +1296,33 @@ class NaDiT(nn.Module): x, timestep, context, # l c - txt_shape, # b 1 disable_cache: bool = True, # for test # TODO ? + **kwargs ): + transformer_options = kwargs.get("transformer_options", {}) + c_or_u_list = transformer_options.get("cond_or_uncond", []) + cond_latent = c_or_u_list[0]["condition"] + pos_cond, neg_cond = context.chunk(2, dim=0) - pos_cond, pos_shape = flatten(pos_cond) - neg_cond, neg_shape = flatten(neg_cond) - diff = abs(pos_shape.shape[1] - neg_shape.shape[1]) - if pos_shape.shape[1] > neg_shape.shape[1]: - neg_shape = F.pad(neg_shape, (0, 0, 0, diff)) - neg_cond = F.pad(neg_cond, (0, 0, 0, diff)) - else: - pos_shape = F.pad(pos_shape, (0, 0, 0, diff)) - pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) + # txt_shape should be the same for both + pos_cond, txt_shape = flatten(pos_cond) + neg_cond, _ = flatten(neg_cond) + txt = torch.cat([pos_cond, neg_cond], dim = 0) + txt_shape[0] *= 2 + vid = x - txt = context vid, vid_shape = flatten(x) + + vid = torch.cat([cond_latent, vid]) if txt_shape.size(-1) == 1 and self.need_txt_repeat: txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) - # slice vid after patching in when using sequence parallelism + txt = self.txt_in(txt) vid, vid_shape = self.vid_in(vid, vid_shape) - # Embedding input. emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) - # Body cache = Cache(disable=disable_cache) for i, block in enumerate(self.blocks): vid, txt, vid_shape, txt_shape = block( diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 3a0f8cfed..40c592a2b 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -3,10 +3,9 @@ from typing import Literal, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F -from diffusers.models.attention_processor import Attention from einops import rearrange -from model import safe_pad_operation +from comfy.ldm.seedvr.model import safe_pad_operation from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution from comfy.ldm.modules.attention import optimized_attention @@ -216,67 +215,37 @@ class Attention(nn.Module): return hidden_states -def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str): - """ - Inflate a 2D convolution weight matrix to a 3D one. - Parameters: - weight_2d: The weight matrix of 2D conv to be inflated. - weight_3d: The weight matrix of 3D conv to be initialized. - inflation_mode: the mode of inflation - """ - assert inflation_mode in ["tail", "replicate"] - assert weight_3d.shape[:2] == weight_2d.shape[:2] +def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor): with torch.no_grad(): - if inflation_mode == "replicate": - depth = weight_3d.size(2) - weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) - else: - weight_3d.fill_(0.0) - weight_3d[:, :, -1].copy_(weight_2d) + depth = weight_3d.size(2) + weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) return weight_3d - -def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str): - """ - Inflate a 2D convolution bias tensor to a 3D one - Parameters: - bias_2d: The bias tensor of 2D conv to be inflated. - bias_3d: The bias tensor of 3D conv to be initialized. - inflation_mode: Placeholder to align `inflate_weight`. - """ - assert bias_3d.shape == bias_2d.shape +def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor): with torch.no_grad(): bias_3d.copy_(bias_2d) return bias_3d def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): - """ - the main function to inflated 2D parameters to 3D. - """ weight_name = prefix + "weight" bias_name = prefix + "bias" if weight_name in state_dict: weight_2d = state_dict[weight_name] if weight_2d.dim() == 4: - # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w) weight_3d = inflate_weight_fn( weight_2d=weight_2d, weight_3d=layer.weight, - inflation_mode=layer.inflation_mode, ) state_dict[weight_name] = weight_3d else: return state_dict - # It's a 3d state dict, should not do inflation on both bias and weight. if bias_name in state_dict: bias_2d = state_dict[bias_name] if bias_2d.dim() == 1: - # Assuming the 2D biases are 1D tensors (out_channels,) bias_3d = inflate_bias_fn( bias_2d=bias_2d, bias_3d=layer.bias, - inflation_mode=layer.inflation_mode, ) state_dict[bias_name] = bias_3d return state_dict @@ -384,19 +353,12 @@ class InflatedCausalConv3d(nn.Conv3d): def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): - if self.inflation_mode != "none": - state_dict = modify_state_dict( - self, - state_dict, - prefix, - inflate_weight_fn=inflate_weight, - inflate_bias_fn=inflate_bias, - ) + super()._load_from_state_dict( state_dict, prefix, local_metadata, - (strict and self.inflation_mode == "none"), + strict, missing_keys, unexpected_keys, error_msgs, diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 600c089fa..804878432 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -344,12 +344,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b dit_config = {} + dit_config["image_model"] = "seedvr2" dit_config["vid_dim"] = 2560 dit_config["heads"] = 20 dit_config["num_layers"] = 32 dit_config["norm_eps"] = 1.0e-05 dit_config["qk_rope"] = None dit_config["mlp_type"] = "swiglu" + dit_config["vid_out_norm"] = True return dit_config diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2301b1188..4162a1f5e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1154,20 +1154,21 @@ class Chroma(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) -class SeedVR2(supported_models_base.Base): +class SeedVR2(supported_models_base.BASE): unet_config = { - "image_mode": "seedvr2" + "image_model": "seedvr2" } latent_format = comfy.latent_formats.SeedVR2 vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] supported_inference_dtypes = [torch.bfloat16, torch.float32] def get_model(self, state_dict, prefix = "", device=None): out = model_base.SeedVR2(self, device=device) return out def clip_target(self, state_dict={}): - return None + return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.SD3ClipModel) class ACEStep(supported_models_base.BASE): unet_config = { diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 9d4e8bf34..e2fa10427 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -4,6 +4,7 @@ import torch import math from einops import rearrange +import torch.nn.functional as F from torchvision.transforms import functional as TVF from torchvision.transforms import Lambda, Normalize from torchvision.transforms.functional import InterpolationMode @@ -108,12 +109,13 @@ class SeedVR2InputProcessing(io.ComfyNode): io.Int.Input("resolution_width") ], outputs = [ - io.Image.Output("images") + io.Image.Output("processed_images") ] ) @classmethod def execute(cls, images, resolution_height, resolution_width): + images = images.permute(0, 3, 1, 2) max_area = ((resolution_height * resolution_width)** 0.5) ** 2 clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) normalize = Normalize(0.5, 0.5) @@ -134,7 +136,7 @@ class SeedVR2Conditioning(io.ComfyNode): inputs=[ io.Conditioning.Input("text_positive_conditioning"), io.Conditioning.Input("text_negative_conditioning"), - io.Conditioning.Input("vae_conditioning") + io.Latent.Input("vae_conditioning") ], outputs=[io.Conditioning.Output(display_name = "positive"), io.Conditioning.Output(display_name = "negative"), @@ -143,7 +145,8 @@ class SeedVR2Conditioning(io.ComfyNode): @classmethod def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput: - # TODO: should do the flattening logic as with the original code + + vae_conditioning = vae_conditioning["samples"] pos_cond = text_positive_conditioning[0][0] neg_cond = text_negative_conditioning[0][0] @@ -160,14 +163,18 @@ class SeedVR2Conditioning(io.ComfyNode): cond = inter(vae_conditioning, aug_noises, t) condition = get_conditions(noises, cond) - # TODO / FIXME - pos_cond = torch.cat([condition, pos_cond], dim = 0) - neg_cond = torch.cat([condition, neg_cond], dim = 0) + pos_shape = pos_cond.shape[1] + neg_shape = neg_shape.shape[1] + diff = abs(pos_shape.shape[1] - neg_shape.shape[1]) + if pos_shape.shape[1] > neg_shape.shape[1]: + neg_cond = F.pad(neg_cond, (0, 0, 0, diff)) + else: + pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) - negative = [[pos_cond, {}]] - positive = [[neg_cond, {}]] + negative = [[pos_cond, {"condition": condition}]] + positive = [[neg_cond, {"condition": condition}]] - return io.NodeOutput(positive, negative, noises) + return io.NodeOutput(positive, negative, {"samples": noises}) class SeedVRExtension(ComfyExtension): @override diff --git a/nodes.py b/nodes.py index 1b465b9e6..72e9c6066 100644 --- a/nodes.py +++ b/nodes.py @@ -2283,7 +2283,8 @@ def init_builtin_extra_nodes(): "nodes_string.py", "nodes_camera_trajectory.py", "nodes_edit_model.py", - "nodes_tcfg.py" + "nodes_tcfg.py", + "nodes_seedvr.py" ] import_failed = []