mostly fixing mistakes

Yousef Rafat 2025-12-09 00:16:17 +02:00
parent 44a5bf353a
commit f030b3afc8
6 changed files with 49 additions and 72 deletions

View File

@@ -1187,8 +1187,12 @@ class NaDiT(nn.Module):
         rope_dim = 128,
         rope_type = "mmrope3d",
         vid_out_norm: Optional[str] = None,
+        device = None,
+        dtype = None,
+        operations = None,
         **kwargs,
     ):
+        self.dtype = dtype
         window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
         txt_dim = vid_dim
         emb_dim = vid_dim * 6
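
For context, the three new constructor kwargs follow the usual ComfyUI pattern: layers are built through an injectable `operations` namespace so the loader can substitute offloaded or quantized layer implementations, while `device`/`dtype` control where and how weights are created. A minimal sketch of the idea, with `TinyBlock` as a hypothetical stand-in rather than the real NaDiT internals:

    import torch.nn as nn

    class TinyBlock(nn.Module):
        def __init__(self, dim, device=None, dtype=None, operations=nn):
            super().__init__()
            # Building layers through `operations` instead of `nn` directly lets
            # the caller inject alternative Linear implementations; device/dtype
            # decide where the weights are materialized.
            self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)

        def forward(self, x):
            return self.proj(x)
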
@@ -1292,33 +1296,33 @@ class NaDiT(nn.Module):
         x,
         timestep,
         context,  # l c
-        txt_shape,  # b 1
         disable_cache: bool = True,  # for test # TODO ?
+        **kwargs
     ):
+        transformer_options = kwargs.get("transformer_options", {})
+        c_or_u_list = transformer_options.get("cond_or_uncond", [])
+        cond_latent = c_or_u_list[0]["condition"]
         pos_cond, neg_cond = context.chunk(2, dim=0)
-        pos_cond, pos_shape = flatten(pos_cond)
-        neg_cond, neg_shape = flatten(neg_cond)
-        diff = abs(pos_shape.shape[1] - neg_shape.shape[1])
-        if pos_shape.shape[1] > neg_shape.shape[1]:
-            neg_shape = F.pad(neg_shape, (0, 0, 0, diff))
-            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
-        else:
-            pos_shape = F.pad(pos_shape, (0, 0, 0, diff))
-            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
+        # txt_shape should be the same for both
+        pos_cond, txt_shape = flatten(pos_cond)
+        neg_cond, _ = flatten(neg_cond)
+        txt = torch.cat([pos_cond, neg_cond], dim = 0)
+        txt_shape[0] *= 2
         vid = x
-        txt = context
         vid, vid_shape = flatten(x)
+        vid = torch.cat([cond_latent, vid])
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
-        # slice vid after patching in when using sequence parallelism
         txt = self.txt_in(txt)
         vid, vid_shape = self.vid_in(vid, vid_shape)
-        # Embedding input.
         emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
-        # Body
         cache = Cache(disable=disable_cache)
         for i, block in enumerate(self.blocks):
             vid, txt, vid_shape, txt_shape = block(
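
A rough sketch of what the rewritten conditioning path does, assuming a hypothetical `flatten()` that merges the batch dimension into the sequence and reports per-sample shapes (the real helper lives in the SeedVR model code and differs in detail):

    import torch

    def flatten(t):
        # hypothetical: record per-sample shapes, then merge batch into sequence
        shapes = torch.tensor([list(t.shape[1:])] * t.shape[0])
        return t.reshape(-1, t.shape[-1]), shapes

    context = torch.randn(2, 77, 2560)            # positive and negative stacked on dim 0
    pos_cond, neg_cond = context.chunk(2, dim=0)  # split the two prompts
    pos_cond, txt_shape = flatten(pos_cond)       # shapes match after node-side padding
    neg_cond, _ = flatten(neg_cond)
    txt = torch.cat([pos_cond, neg_cond], dim=0)  # one flat text stream for the blocks
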

View File

@@ -3,10 +3,9 @@ from typing import Literal, Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from diffusers.models.attention_processor import Attention
 from einops import rearrange
-from model import safe_pad_operation
+from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
 from comfy.ldm.modules.attention import optimized_attention
@@ -216,67 +215,37 @@ class Attention(nn.Module):
         return hidden_states

-def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str):
-    """
-    Inflate a 2D convolution weight matrix to a 3D one.
-    Parameters:
-        weight_2d: The weight matrix of 2D conv to be inflated.
-        weight_3d: The weight matrix of 3D conv to be initialized.
-        inflation_mode: the mode of inflation
-    """
-    assert inflation_mode in ["tail", "replicate"]
-    assert weight_3d.shape[:2] == weight_2d.shape[:2]
+def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor):
     with torch.no_grad():
-        if inflation_mode == "replicate":
-            depth = weight_3d.size(2)
-            weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
-        else:
-            weight_3d.fill_(0.0)
-            weight_3d[:, :, -1].copy_(weight_2d)
+        depth = weight_3d.size(2)
+        weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
     return weight_3d

-def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str):
-    """
-    Inflate a 2D convolution bias tensor to a 3D one
-    Parameters:
-        bias_2d: The bias tensor of 2D conv to be inflated.
-        bias_3d: The bias tensor of 3D conv to be initialized.
-        inflation_mode: Placeholder to align `inflate_weight`.
-    """
-    assert bias_3d.shape == bias_2d.shape
+def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor):
     with torch.no_grad():
         bias_3d.copy_(bias_2d)
     return bias_3d

 def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
-    """
-    the main function to inflated 2D parameters to 3D.
-    """
     weight_name = prefix + "weight"
     bias_name = prefix + "bias"
     if weight_name in state_dict:
         weight_2d = state_dict[weight_name]
         if weight_2d.dim() == 4:
-            # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w)
             weight_3d = inflate_weight_fn(
                 weight_2d=weight_2d,
                 weight_3d=layer.weight,
-                inflation_mode=layer.inflation_mode,
             )
             state_dict[weight_name] = weight_3d
     else:
+        # It's a 3d state dict, should not do inflation on both bias and weight.
         return state_dict
     if bias_name in state_dict:
         bias_2d = state_dict[bias_name]
         if bias_2d.dim() == 1:
-            # Assuming the 2D biases are 1D tensors (out_channels,)
             bias_3d = inflate_bias_fn(
                 bias_2d=bias_2d,
                 bias_3d=layer.bias,
-                inflation_mode=layer.inflation_mode,
             )
             state_dict[bias_name] = bias_3d
     return state_dict
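
The simplified `inflate_weight` keeps only the "replicate" behaviour: stack the 2D kernel along the new depth axis and divide by depth, so the inflated 3D conv reproduces the 2D conv on temporally constant input. A self-contained sketch of that property (toy channel counts):

    import torch
    import torch.nn as nn

    conv2d = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    conv3d = nn.Conv3d(3, 8, kernel_size=(4, 3, 3), padding=(0, 1, 1), bias=False)

    with torch.no_grad():
        depth = conv3d.weight.size(2)  # replicate inflation, as in the diff above
        conv3d.weight.copy_(conv2d.weight.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)

    x = torch.randn(1, 3, 16, 16)
    frames = x[:, :, None].repeat(1, 1, 4, 1, 1)    # constant along the time axis
    out2d = conv2d(x)
    out3d = conv3d(frames)[:, :, 0]
    print(torch.allclose(out2d, out3d, atol=1e-4))  # True: averaging over depth cancels
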
@@ -384,19 +353,12 @@ class InflatedCausalConv3d(nn.Conv3d):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        if self.inflation_mode != "none":
-            state_dict = modify_state_dict(
-                self,
-                state_dict,
-                prefix,
-                inflate_weight_fn=inflate_weight,
-                inflate_bias_fn=inflate_bias,
-            )
         super()._load_from_state_dict(
             state_dict,
             prefix,
             local_metadata,
-            (strict and self.inflation_mode == "none"),
+            strict,
             missing_keys,
             unexpected_keys,
             error_msgs,
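
For reference, `_load_from_state_dict` is the per-module hook PyTorch calls during `load_state_dict` (on a copy of the dict), so a module can patch its own checkpoint entries before the default loader consumes them. A minimal illustration with a hypothetical legacy key rename, not the SeedVR code:

    import torch
    import torch.nn as nn

    class LegacyLinear(nn.Linear):
        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs):
            old_key = prefix + "W"  # hypothetical legacy name for the weight
            if old_key in state_dict:
                state_dict[prefix + "weight"] = state_dict.pop(old_key)
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                          missing_keys, unexpected_keys, error_msgs)

    layer = LegacyLinear(4, 4)
    layer.load_state_dict({"W": torch.randn(4, 4), "bias": torch.zeros(4)})
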

View File

@@ -344,12 +344,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys:  # seedvr2 3b
         dit_config = {}
+        dit_config["image_model"] = "seedvr2"
         dit_config["vid_dim"] = 2560
         dit_config["heads"] = 20
         dit_config["num_layers"] = 32
         dit_config["norm_eps"] = 1.0e-05
         dit_config["qk_rope"] = None
         dit_config["mlp_type"] = "swiglu"
+        dit_config["vid_out_norm"] = True
         return dit_config
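
The surrounding function detects architectures by sniffing checkpoint keys; a condensed sketch of this branch, with the values copied from the hunk and the function itself simplified:

    def detect_seedvr2(state_dict_keys, key_prefix=""):
        # a key unique to the SeedVR2 3B checkpoint identifies the architecture
        if "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys:
            return {
                "image_model": "seedvr2",  # added so model selection can route on it
                "vid_dim": 2560,
                "heads": 20,
                "num_layers": 32,
                "norm_eps": 1.0e-05,
                "qk_rope": None,
                "mlp_type": "swiglu",
                "vid_out_norm": True,      # the other field this commit adds
            }
        return None
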

View File

@@ -1154,20 +1154,21 @@ class Chroma(supported_models_base.BASE):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))

-class SeedVR2(supported_models_base.Base):
+class SeedVR2(supported_models_base.BASE):
     unet_config = {
-        "image_mode": "seedvr2"
+        "image_model": "seedvr2"
     }

     latent_format = comfy.latent_formats.SeedVR2
     vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
     supported_inference_dtypes = [torch.bfloat16, torch.float32]

     def get_model(self, state_dict, prefix = "", device=None):
         out = model_base.SeedVR2(self, device=device)
         return out

     def clip_target(self, state_dict={}):
-        return None
+        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.SD3ClipModel)

 class ACEStep(supported_models_base.BASE):
     unet_config = {
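
The `"image_mode"` to `"image_model"` fix matters because ComfyUI selects a model class by checking that every `unet_config` entry appears in the detected config, so a misspelled key can never match. A simplified sketch of that matching, with assumed semantics rather than the exact ComfyUI code:

    def matches(unet_config, detected):
        # every key/value declared by the model class must appear in the detection
        return all(detected.get(k) == v for k, v in unet_config.items())

    detected = {"image_model": "seedvr2", "vid_dim": 2560}
    print(matches({"image_mode": "seedvr2"}, detected))   # False: the typo never matched
    print(matches({"image_model": "seedvr2"}, detected))  # True after the fix
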

View File

@@ -4,6 +4,7 @@ import torch
 import math
 from einops import rearrange
+import torch.nn.functional as F
 from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
 from torchvision.transforms.functional import InterpolationMode
@@ -108,12 +109,13 @@ class SeedVR2InputProcessing(io.ComfyNode):
                 io.Int.Input("resolution_width")
             ],
             outputs = [
-                io.Image.Output("images")
+                io.Image.Output("processed_images")
             ]
         )

     @classmethod
     def execute(cls, images, resolution_height, resolution_width):
+        images = images.permute(0, 3, 1, 2)
         max_area = ((resolution_height * resolution_width) ** 0.5) ** 2
         clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
         normalize = Normalize(0.5, 0.5)
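
The added `permute` is needed because ComfyUI IMAGE tensors are BHWC in [0, 1], while torchvision transforms expect BCHW. A minimal sketch of the preprocessing order, with shapes assumed for illustration:

    import torch
    from torchvision.transforms import Lambda, Normalize

    images = torch.rand(1, 512, 512, 3)   # ComfyUI layout: B H W C
    images = images.permute(0, 3, 1, 2)   # -> B C H W for torchvision ops
    clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
    normalize = Normalize(0.5, 0.5)       # maps [0, 1] to [-1, 1]
    images = normalize(clip(images))
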
@@ -134,7 +136,7 @@ class SeedVR2Conditioning(io.ComfyNode):
             inputs=[
                 io.Conditioning.Input("text_positive_conditioning"),
                 io.Conditioning.Input("text_negative_conditioning"),
-                io.Conditioning.Input("vae_conditioning")
+                io.Latent.Input("vae_conditioning")
             ],
             outputs=[io.Conditioning.Output(display_name = "positive"),
                      io.Conditioning.Output(display_name = "negative"),
@@ -143,7 +145,8 @@ class SeedVR2Conditioning(io.ComfyNode):
     @classmethod
     def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
-        # TODO: should do the flattening logic as with the original code
+        vae_conditioning = vae_conditioning["samples"]
         pos_cond = text_positive_conditioning[0][0]
         neg_cond = text_negative_conditioning[0][0]
@@ -160,14 +163,18 @@ class SeedVR2Conditioning(io.ComfyNode):
         cond = inter(vae_conditioning, aug_noises, t)
         condition = get_conditions(noises, cond)

-        # TODO / FIXME
-        pos_cond = torch.cat([condition, pos_cond], dim = 0)
-        neg_cond = torch.cat([condition, neg_cond], dim = 0)
+        pos_shape = pos_cond.shape[1]
+        neg_shape = neg_cond.shape[1]
+        diff = abs(pos_shape - neg_shape)
+        if pos_shape > neg_shape:
+            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
+        else:
+            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

-        negative = [[pos_cond, {}]]
-        positive = [[neg_cond, {}]]
+        negative = [[pos_cond, {"condition": condition}]]
+        positive = [[neg_cond, {"condition": condition}]]

-        return io.NodeOutput(positive, negative, noises)
+        return io.NodeOutput(positive, negative, {"samples": noises})

 class SeedVRExtension(ComfyExtension):
     @override
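
The padding block relies on `F.pad`'s pair convention: the first pair addresses the last dim and the second pair the one before it, so `(0, 0, 0, diff)` appends `diff` zero rows along the sequence dimension. A small demonstration with toy shapes:

    import torch
    import torch.nn.functional as F

    pos_cond = torch.randn(77, 2560)   # (sequence, channels)
    neg_cond = torch.randn(50, 2560)
    diff = abs(pos_cond.shape[0] - neg_cond.shape[0])
    if pos_cond.shape[0] > neg_cond.shape[0]:
        neg_cond = F.pad(neg_cond, (0, 0, 0, diff))  # zero rows appended at the end
    else:
        pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
    assert pos_cond.shape == neg_cond.shape  # now the two prompts batch together
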

View File

@@ -2283,7 +2283,8 @@ def init_builtin_extra_nodes():
         "nodes_string.py",
         "nodes_camera_trajectory.py",
         "nodes_edit_model.py",
-        "nodes_tcfg.py"
+        "nodes_tcfg.py",
+        "nodes_seedvr.py"
     ]
     import_failed = []