mostly fixing mistakes
This commit is contained in:
parent 44a5bf353a
commit f030b3afc8
@@ -1187,8 +1187,12 @@ class NaDiT(nn.Module):
        rope_dim = 128,
        rope_type = "mmrope3d",
        vid_out_norm: Optional[str] = None,
        device = None,
        dtype = None,
        operations = None,
        **kwargs,
    ):
        self.dtype = dtype
        window_method = num_layers // 2 * ["720pwin_by_size_bysize", "720pswin_by_size_bysize"]
        txt_dim = vid_dim
        emb_dim = vid_dim * 6
@@ -1292,33 +1296,33 @@ class NaDiT(nn.Module):
        x,
        timestep,
        context,  # l c
        txt_shape,  # b 1
        disable_cache: bool = True,  # for test # TODO ?
        **kwargs
    ):
        transformer_options = kwargs.get("transformer_options", {})
        c_or_u_list = transformer_options.get("cond_or_uncond", [])
        cond_latent = c_or_u_list[0]["condition"]

        pos_cond, neg_cond = context.chunk(2, dim=0)
        pos_cond, pos_shape = flatten(pos_cond)
        neg_cond, neg_shape = flatten(neg_cond)
        diff = abs(pos_shape.shape[1] - neg_shape.shape[1])
        if pos_shape.shape[1] > neg_shape.shape[1]:
            neg_shape = F.pad(neg_shape, (0, 0, 0, diff))
            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
        else:
            pos_shape = F.pad(pos_shape, (0, 0, 0, diff))
            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
        # txt_shape should be the same for both
        pos_cond, txt_shape = flatten(pos_cond)
        neg_cond, _ = flatten(neg_cond)
        txt = torch.cat([pos_cond, neg_cond], dim=0)
        txt_shape[0] *= 2

        vid = x
        txt = context
        vid, vid_shape = flatten(x)

        vid = torch.cat([cond_latent, vid])
        if txt_shape.size(-1) == 1 and self.need_txt_repeat:
            txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
        # slice vid after patching in when using sequence parallelism

        txt = self.txt_in(txt)

        vid, vid_shape = self.vid_in(vid, vid_shape)

        # Embedding input.
        emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)

        # Body
        cache = Cache(disable=disable_cache)
        for i, block in enumerate(self.blocks):
            vid, txt, vid_shape, txt_shape = block(
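The forward hunk above right-pads the shorter of the two flattened text streams so the cond and uncond embeddings can ride in a single batch. A minimal standalone sketch of that padding step (tensor names and sizes here are illustrative, not the model's API):

    import torch
    import torch.nn.functional as F

    # Illustrative stand-ins for the positive/negative text embeddings.
    pos = torch.randn(1, 10, 64)  # (batch, tokens, channels)
    neg = torch.randn(1, 7, 64)

    # F.pad with (0, 0, 0, diff) pads the last dim by (0, 0) and the
    # second-to-last (token) dim by (0, diff), right-padding the shorter one.
    diff = abs(pos.shape[1] - neg.shape[1])
    if pos.shape[1] > neg.shape[1]:
        neg = F.pad(neg, (0, 0, 0, diff))
    else:
        pos = F.pad(pos, (0, 0, 0, diff))

    txt = torch.cat([pos, neg], dim=0)  # one batch carrying cond and uncond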
@@ -3,10 +3,9 @@ from typing import Literal, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
-from diffusers.models.attention_processor import Attention
from einops import rearrange

-from model import safe_pad_operation
+from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
from comfy.ldm.modules.attention import optimized_attention
@@ -216,67 +215,37 @@ class Attention(nn.Module):
        return hidden_states


-def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str):
-    """
-    Inflate a 2D convolution weight matrix to a 3D one.
-    Parameters:
-        weight_2d: The weight matrix of the 2D conv to be inflated.
-        weight_3d: The weight matrix of the 3D conv to be initialized.
-        inflation_mode: the mode of inflation
-    """
-    assert inflation_mode in ["tail", "replicate"]
-    assert weight_3d.shape[:2] == weight_2d.shape[:2]
+def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor):
    with torch.no_grad():
        if inflation_mode == "replicate":
            depth = weight_3d.size(2)
            weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
        else:
            weight_3d.fill_(0.0)
            weight_3d[:, :, -1].copy_(weight_2d)
    return weight_3d


-def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str):
-    """
-    Inflate a 2D convolution bias tensor to a 3D one.
-    Parameters:
-        bias_2d: The bias tensor of the 2D conv to be inflated.
-        bias_3d: The bias tensor of the 3D conv to be initialized.
-        inflation_mode: Placeholder to align with `inflate_weight`.
-    """
-    assert bias_3d.shape == bias_2d.shape
+def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor):
    with torch.no_grad():
        bias_3d.copy_(bias_2d)
    return bias_3d

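As a reference for the two modes named in the removed docstring, here is a self-contained sketch of what each inflation does to a Conv2d kernel (shapes are made up for illustration): "replicate" spreads the 2D kernel across the temporal depth, "tail" places it only in the last temporal slice.

    import torch

    weight_2d = torch.randn(8, 4, 3, 3)     # Conv2d weight: (out, in, h, w)
    weight_3d = torch.empty(8, 4, 5, 3, 3)  # Conv3d weight: (out, in, depth, h, w)

    # "replicate": every depth slice holds the 2D kernel divided by depth, so a
    # temporally constant input yields the same response as the 2D conv.
    depth = weight_3d.size(2)
    weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)

    # "tail": zeros everywhere except the final depth slice, so a causal conv
    # reproduces the 2D conv on the current frame alone.
    weight_3d.zero_()
    weight_3d[:, :, -1].copy_(weight_2d)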
def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
    """
    The main function to inflate 2D parameters to 3D.
    """
    weight_name = prefix + "weight"
    bias_name = prefix + "bias"
    if weight_name in state_dict:
        weight_2d = state_dict[weight_name]
        if weight_2d.dim() == 4:
            # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w)
            weight_3d = inflate_weight_fn(
                weight_2d=weight_2d,
                weight_3d=layer.weight,
                inflation_mode=layer.inflation_mode,
            )
            state_dict[weight_name] = weight_3d
        else:
            # It's a 3D state dict; inflation should be done on neither weight nor bias.
            return state_dict
    if bias_name in state_dict:
        bias_2d = state_dict[bias_name]
        if bias_2d.dim() == 1:
            # Assuming the 2D biases are 1D tensors (out_channels,)
            bias_3d = inflate_bias_fn(
                bias_2d=bias_2d,
                bias_3d=layer.bias,
                inflation_mode=layer.inflation_mode,
            )
            state_dict[bias_name] = bias_3d
    return state_dict
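A rough usage sketch for modify_state_dict, assuming the inflation_mode-taking variants of the inflate functions shown above and a layer that exposes an inflation_mode attribute; the checkpoint keys and shapes are hypothetical:

    import torch
    import torch.nn as nn

    layer = nn.Conv3d(4, 8, kernel_size=(5, 3, 3))
    layer.inflation_mode = "tail"  # attribute assumed by modify_state_dict

    # Hypothetical checkpoint still carrying 2D conv parameters.
    sd = {"conv.weight": torch.randn(8, 4, 3, 3), "conv.bias": torch.randn(8)}
    sd = modify_state_dict(layer, sd, "conv.",
                           inflate_weight_fn=inflate_weight,
                           inflate_bias_fn=inflate_bias)
    assert sd["conv.weight"].dim() == 5  # 2D weight replaced by the inflated 3D one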
@@ -384,19 +353,12 @@ class InflatedCausalConv3d(nn.Conv3d):
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        if self.inflation_mode != "none":
            state_dict = modify_state_dict(
                self,
                state_dict,
                prefix,
                inflate_weight_fn=inflate_weight,
                inflate_bias_fn=inflate_bias,
            )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
-            (strict and self.inflation_mode == "none"),
+            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
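The override above rides on PyTorch's per-module load hook: _load_from_state_dict lets a module rewrite incoming checkpoint tensors before nn.Module consumes them. The bare pattern, as a minimal sketch (the dtype rewrite is just a placeholder transformation):

    import torch
    import torch.nn as nn

    class PatchedLinear(nn.Linear):
        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs):
            # Rewrite checkpoint tensors in place before the default loader runs.
            key = prefix + "weight"
            if key in state_dict:
                state_dict[key] = state_dict[key].float()
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                          missing_keys, unexpected_keys, error_msgs)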
@@ -344,12 +344,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

    elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys:  # seedvr2 3b
        dit_config = {}
        dit_config["image_model"] = "seedvr2"
        dit_config["vid_dim"] = 2560
        dit_config["heads"] = 20
        dit_config["num_layers"] = 32
        dit_config["norm_eps"] = 1.0e-05
        dit_config["qk_rope"] = None
        dit_config["mlp_type"] = "swiglu"
        dit_config["vid_out_norm"] = True

        return dit_config
@@ -1154,20 +1154,21 @@ class Chroma(supported_models_base.BASE):
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))

-class SeedVR2(supported_models_base.Base):
+class SeedVR2(supported_models_base.BASE):
    unet_config = {
-        "image_mode": "seedvr2"
+        "image_model": "seedvr2"
    }
    latent_format = comfy.latent_formats.SeedVR2

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]
    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    def get_model(self, state_dict, prefix = "", device=None):
        out = model_base.SeedVR2(self, device=device)
        return out
    def clip_target(self, state_dict={}):
-        return None
+        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.SD3ClipModel)

class ACEStep(supported_models_base.BASE):
    unet_config = {
@@ -4,6 +4,7 @@ import torch
import math
from einops import rearrange

+import torch.nn.functional as F
from torchvision.transforms import functional as TVF
from torchvision.transforms import Lambda, Normalize
from torchvision.transforms.functional import InterpolationMode
@@ -108,12 +109,13 @@ class SeedVR2InputProcessing(io.ComfyNode):
                io.Int.Input("resolution_width")
            ],
            outputs = [
-                io.Image.Output("images")
+                io.Image.Output("processed_images")
            ]
        )

    @classmethod
    def execute(cls, images, resolution_height, resolution_width):
        images = images.permute(0, 3, 1, 2)
        max_area = ((resolution_height * resolution_width) ** 0.5) ** 2
        clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
        normalize = Normalize(0.5, 0.5)
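For context, ComfyUI IMAGE tensors arrive as (batch, height, width, channels) floats in [0, 1], so the permute(0, 3, 1, 2) above converts them to the channels-first layout that torchvision transforms expect. A one-line illustration:

    import torch

    images = torch.rand(1, 512, 512, 3)      # ComfyUI layout: B, H, W, C in [0, 1]
    images_chw = images.permute(0, 3, 1, 2)  # torchvision layout: B, C, H, W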
@@ -134,7 +136,7 @@ class SeedVR2Conditioning(io.ComfyNode):
            inputs=[
                io.Conditioning.Input("text_positive_conditioning"),
                io.Conditioning.Input("text_negative_conditioning"),
-                io.Conditioning.Input("vae_conditioning")
+                io.Latent.Input("vae_conditioning")
            ],
            outputs=[io.Conditioning.Output(display_name = "positive"),
                     io.Conditioning.Output(display_name = "negative"),
@@ -143,7 +145,8 @@ class SeedVR2Conditioning(io.ComfyNode):

    @classmethod
    def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
        # TODO: should do the flattening logic as with the original code

+        vae_conditioning = vae_conditioning["samples"]
        pos_cond = text_positive_conditioning[0][0]
        neg_cond = text_negative_conditioning[0][0]
@@ -160,14 +163,18 @@ class SeedVR2Conditioning(io.ComfyNode):
        cond = inter(vae_conditioning, aug_noises, t)
        condition = get_conditions(noises, cond)

        # TODO / FIXME
        pos_cond = torch.cat([condition, pos_cond], dim = 0)
        neg_cond = torch.cat([condition, neg_cond], dim = 0)
        pos_shape = pos_cond.shape[1]
        neg_shape = neg_cond.shape[1]
        diff = abs(pos_shape - neg_shape)
        if pos_shape > neg_shape:
            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
        else:
            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

-        negative = [[pos_cond, {}]]
-        positive = [[neg_cond, {}]]
+        negative = [[pos_cond, {"condition": condition}]]
+        positive = [[neg_cond, {"condition": condition}]]

-        return io.NodeOutput(positive, negative, noises)
+        return io.NodeOutput(positive, negative, {"samples": noises})

class SeedVRExtension(ComfyExtension):
    @override
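For orientation: ComfyUI conditioning travels as a list of [tensor, options-dict] pairs, and latents travel as {"samples": tensor} dicts, which is what the -/+ pairs above move toward. Stashing the latent condition in the options dict appears to be how the NaDiT forward hunk earlier recovers it through transformer_options. A shape-level sketch with made-up sizes:

    import torch

    text_embed = torch.randn(1, 77, 2560)    # hypothetical token embeddings
    condition = torch.randn(21, 16, 64, 64)  # hypothetical latent condition

    positive = [[text_embed, {"condition": condition}]]  # conditioning entry
    latent = {"samples": torch.randn(21, 16, 64, 64)}    # latent dict convention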