mirror of https://github.com/comfyanonymous/ComfyUI.git

commit f030b3afc8 (parent 44a5bf353a): mostly fixing mistakes
@@ -1187,8 +1187,12 @@ class NaDiT(nn.Module):
         rope_dim = 128,
         rope_type = "mmrope3d",
         vid_out_norm: Optional[str] = None,
+        device = None,
+        dtype = None,
+        operations = None,
         **kwargs,
     ):
+        self.dtype = dtype
         window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
         txt_dim = vid_dim
         emb_dim = vid_dim * 6
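The new device/dtype/operations parameters follow ComfyUI's usual pattern of threading a custom-ops object through every module, so layers can be built on the right device and dtype without running weight initialization. A minimal sketch of a module consuming them this way (illustrative only, not code from this commit):

    import torch.nn as nn

    class TinyBlock(nn.Module):
        def __init__(self, dim, device=None, dtype=None, operations=None):
            super().__init__()
            ops = operations if operations is not None else nn  # fall back to plain torch.nn
            # comfy.ops-style Linear takes the same factory kwargs as nn.Linear
            self.proj = ops.Linear(dim, dim, device=device, dtype=dtype)

        def forward(self, x):
            return self.proj(x)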
@@ -1292,33 +1296,33 @@ class NaDiT(nn.Module):
         x,
         timestep,
         context, # l c
-        txt_shape, # b 1
         disable_cache: bool = True, # for test # TODO ?
+        **kwargs
     ):
+        transformer_options = kwargs.get("transformer_options", {})
+        c_or_u_list = transformer_options.get("cond_or_uncond", [])
+        cond_latent = c_or_u_list[0]["condition"]

         pos_cond, neg_cond = context.chunk(2, dim=0)
-        pos_cond, pos_shape = flatten(pos_cond)
-        neg_cond, neg_shape = flatten(neg_cond)
-        diff = abs(pos_shape.shape[1] - neg_shape.shape[1])
-        if pos_shape.shape[1] > neg_shape.shape[1]:
-            neg_shape = F.pad(neg_shape, (0, 0, 0, diff))
-            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
-        else:
-            pos_shape = F.pad(pos_shape, (0, 0, 0, diff))
-            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
+        # txt_shape should be the same for both
+        pos_cond, txt_shape = flatten(pos_cond)
+        neg_cond, _ = flatten(neg_cond)
+        txt = torch.cat([pos_cond, neg_cond], dim = 0)
+        txt_shape[0] *= 2

         vid = x
-        txt = context
         vid, vid_shape = flatten(x)

+        vid = torch.cat([cond_latent, vid])
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
-        # slice vid after patching in when using sequence parallelism
         txt = self.txt_in(txt)

         vid, vid_shape = self.vid_in(vid, vid_shape)

-        # Embedding input.
         emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)

-        # Body
         cache = Cache(disable=disable_cache)
         for i, block in enumerate(self.blocks):
             vid, txt, vid_shape, txt_shape = block(
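The rewritten forward expects the positive and negative prompts stacked along the batch axis of context, flattens each to (tokens, channels), then re-concatenates them so one transformer pass serves both CFG branches, doubling txt_shape[0] to match. A rough illustration with a stand-in for the model's flatten() (shapes are hypothetical):

    import torch

    ctx = torch.randn(2, 77, 2560)                 # [pos; neg] stacked on dim 0
    pos_cond, neg_cond = ctx.chunk(2, dim=0)       # each (1, 77, 2560)
    pos_flat = pos_cond.reshape(-1, 2560)          # stand-in for flatten(): (77, 2560)
    neg_flat = neg_cond.reshape(-1, 2560)
    txt = torch.cat([pos_flat, neg_flat], dim=0)   # (154, 2560), one batched pass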
@@ -3,10 +3,9 @@ from typing import Literal, Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from diffusers.models.attention_processor import Attention
 from einops import rearrange

-from model import safe_pad_operation
+from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
 from comfy.ldm.modules.attention import optimized_attention

@@ -216,67 +215,37 @@ class Attention(nn.Module):
         return hidden_states


-def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str):
-    """
-    Inflate a 2D convolution weight matrix to a 3D one.
-    Parameters:
-        weight_2d: The weight matrix of 2D conv to be inflated.
-        weight_3d: The weight matrix of 3D conv to be initialized.
-        inflation_mode: the mode of inflation
-    """
-    assert inflation_mode in ["tail", "replicate"]
-    assert weight_3d.shape[:2] == weight_2d.shape[:2]
+def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor):
     with torch.no_grad():
-        if inflation_mode == "replicate":
-            depth = weight_3d.size(2)
-            weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
-        else:
-            weight_3d.fill_(0.0)
-            weight_3d[:, :, -1].copy_(weight_2d)
+        depth = weight_3d.size(2)
+        weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
     return weight_3d


-def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str):
-    """
-    Inflate a 2D convolution bias tensor to a 3D one
-    Parameters:
-        bias_2d: The bias tensor of 2D conv to be inflated.
-        bias_3d: The bias tensor of 3D conv to be initialized.
-        inflation_mode: Placeholder to align `inflate_weight`.
-    """
-    assert bias_3d.shape == bias_2d.shape
+def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor):
     with torch.no_grad():
         bias_3d.copy_(bias_2d)
     return bias_3d


 def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
-    """
-    the main function to inflated 2D parameters to 3D.
-    """
     weight_name = prefix + "weight"
     bias_name = prefix + "bias"
     if weight_name in state_dict:
         weight_2d = state_dict[weight_name]
         if weight_2d.dim() == 4:
-            # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w)
             weight_3d = inflate_weight_fn(
                 weight_2d=weight_2d,
                 weight_3d=layer.weight,
-                inflation_mode=layer.inflation_mode,
             )
             state_dict[weight_name] = weight_3d
         else:
             return state_dict
-            # It's a 3d state dict, should not do inflation on both bias and weight.
     if bias_name in state_dict:
         bias_2d = state_dict[bias_name]
         if bias_2d.dim() == 1:
-            # Assuming the 2D biases are 1D tensors (out_channels,)
             bias_3d = inflate_bias_fn(
                 bias_2d=bias_2d,
                 bias_3d=layer.bias,
-                inflation_mode=layer.inflation_mode,
             )
             state_dict[bias_name] = bias_3d
     return state_dict
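The surviving "replicate" branch repeats the 2D kernel across depth and divides by depth, so the inflated 3D conv reproduces the 2D conv's output on temporally constant input. A quick self-contained check of that property (stride 1 and no padding assumed):

    import torch
    import torch.nn.functional as F

    w2d = torch.randn(4, 3, 3, 3)                       # (out, in, h, w)
    depth = 3
    w3d = w2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth

    x = torch.randn(1, 3, 1, 8, 8).repeat(1, 1, depth, 1, 1)  # constant over time
    y3d = F.conv3d(x, w3d)                              # (1, 4, 1, 6, 6)
    y2d = F.conv2d(x[:, :, 0], w2d)                     # (1, 4, 6, 6)
    assert torch.allclose(y3d[:, :, 0], y2d, atol=1e-5)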
@@ -384,19 +353,12 @@ class InflatedCausalConv3d(nn.Conv3d):
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        if self.inflation_mode != "none":
-            state_dict = modify_state_dict(
-                self,
-                state_dict,
-                prefix,
-                inflate_weight_fn=inflate_weight,
-                inflate_bias_fn=inflate_bias,
-            )
         super()._load_from_state_dict(
             state_dict,
             prefix,
             local_metadata,
-            (strict and self.inflation_mode == "none"),
+            strict,
             missing_keys,
             unexpected_keys,
             error_msgs,
@@ -344,12 +344,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

     elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
         dit_config = {}
+        dit_config["image_model"] = "seedvr2"
         dit_config["vid_dim"] = 2560
         dit_config["heads"] = 20
         dit_config["num_layers"] = 32
         dit_config["norm_eps"] = 1.0e-05
         dit_config["qk_rope"] = None
         dit_config["mlp_type"] = "swiglu"
+        dit_config["vid_out_norm"] = True

         return dit_config

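detect_unet_config dispatches on checkpoint contents: presence of a model-specific key selects the config, and the "image_model" entry added here appears to be what supported-model matching keys on (compare the "image_model": "seedvr2" entry in the SeedVR2 unet_config below). The detection pattern, paraphrased (the key name is from the hunk, the function shell is illustrative):

    def detect(state_dict_keys, key_prefix=""):
        # a key unique to the seedvr2 3b checkpoint selects its config
        if "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys:
            return {"image_model": "seedvr2", "vid_dim": 2560, "heads": 20}
        return None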
@@ -1154,20 +1154,21 @@ class Chroma(supported_models_base.BASE):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))

-class SeedVR2(supported_models_base.Base):
+class SeedVR2(supported_models_base.BASE):
     unet_config = {
-        "image_mode": "seedvr2"
+        "image_model": "seedvr2"
     }
     latent_format = comfy.latent_formats.SeedVR2

     vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
     supported_inference_dtypes = [torch.bfloat16, torch.float32]

     def get_model(self, state_dict, prefix = "", device=None):
         out = model_base.SeedVR2(self, device=device)
         return out
     def clip_target(self, state_dict={}):
-        return None
+        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.SD3ClipModel)

 class ACEStep(supported_models_base.BASE):
     unet_config = {
@@ -4,6 +4,7 @@ import torch
 import math
 from einops import rearrange

+import torch.nn.functional as F
 from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
 from torchvision.transforms.functional import InterpolationMode
@@ -108,12 +109,13 @@ class SeedVR2InputProcessing(io.ComfyNode):
                 io.Int.Input("resolution_width")
             ],
             outputs = [
-                io.Image.Output("images")
+                io.Image.Output("processed_images")
             ]
         )

     @classmethod
     def execute(cls, images, resolution_height, resolution_width):
+        images = images.permute(0, 3, 1, 2)
         max_area = ((resolution_height * resolution_width)** 0.5) ** 2
         clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
         normalize = Normalize(0.5, 0.5)
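The inserted permute converts ComfyUI's IMAGE layout, (batch, height, width, channels), into the (batch, channels, height, width) layout that torchvision transforms such as Normalize expect. A minimal sketch:

    import torch

    images_bhwc = torch.rand(1, 512, 512, 3)        # ComfyUI IMAGE tensor
    images_bchw = images_bhwc.permute(0, 3, 1, 2)   # ready for torchvision ops
    print(images_bchw.shape)                        # torch.Size([1, 3, 512, 512])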
@@ -134,7 +136,7 @@ class SeedVR2Conditioning(io.ComfyNode):
             inputs=[
                 io.Conditioning.Input("text_positive_conditioning"),
                 io.Conditioning.Input("text_negative_conditioning"),
-                io.Conditioning.Input("vae_conditioning")
+                io.Latent.Input("vae_conditioning")
             ],
             outputs=[io.Conditioning.Output(display_name = "positive"),
                      io.Conditioning.Output(display_name = "negative"),
@@ -143,7 +145,8 @@ class SeedVR2Conditioning(io.ComfyNode):

     @classmethod
     def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
-        # TODO: should do the flattening logic as with the original code
+        vae_conditioning = vae_conditioning["samples"]
         pos_cond = text_positive_conditioning[0][0]
         neg_cond = text_negative_conditioning[0][0]

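Unwrapping vae_conditioning["samples"] matches the io.Latent input type introduced above: ComfyUI LATENT values are dicts carrying their tensor under the "samples" key, which is also why the node's latent output below is wrapped the same way. A sketch (the latent shape is illustrative):

    import torch

    latent_in = {"samples": torch.zeros(1, 16, 4, 32, 32)}  # LATENT dict in
    samples = latent_in["samples"]                          # raw tensor for the model
    latent_out = {"samples": samples}                       # LATENT dict back out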
@@ -160,14 +163,18 @@ class SeedVR2Conditioning(io.ComfyNode):
         cond = inter(vae_conditioning, aug_noises, t)
         condition = get_conditions(noises, cond)

-        # TODO / FIXME
-        pos_cond = torch.cat([condition, pos_cond], dim = 0)
-        neg_cond = torch.cat([condition, neg_cond], dim = 0)
+        pos_shape = pos_cond.shape[1]
+        neg_shape = neg_shape.shape[1]
+        diff = abs(pos_shape.shape[1] - neg_shape.shape[1])
+        if pos_shape.shape[1] > neg_shape.shape[1]:
+            neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
+        else:
+            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

-        negative = [[pos_cond, {}]]
-        positive = [[neg_cond, {}]]
+        negative = [[pos_cond, {"condition": condition}]]
+        positive = [[neg_cond, {"condition": condition}]]

-        return io.NodeOutput(positive, negative, noises)
+        return io.NodeOutput(positive, negative, {"samples": noises})

 class SeedVRExtension(ComfyExtension):
     @override
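The padding added here relies on F.pad's convention that the pad tuple is ordered from the last dimension backwards, so (0, 0, 0, diff) leaves the final channel dimension alone and appends diff rows to the token dimension. A small sketch on a (tokens, channels) tensor (sizes are hypothetical):

    import torch
    import torch.nn.functional as F

    neg_cond = torch.randn(64, 2560)           # (tokens, channels)
    padded = F.pad(neg_cond, (0, 0, 0, 13))    # last dim: (0, 0); dim -2: (0, 13)
    print(padded.shape)                        # torch.Size([77, 2560])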
nodes.py
@@ -2283,7 +2283,8 @@ def init_builtin_extra_nodes():
         "nodes_string.py",
         "nodes_camera_trajectory.py",
         "nodes_edit_model.py",
-        "nodes_tcfg.py"
+        "nodes_tcfg.py",
+        "nodes_seedvr.py"
     ]

     import_failed = []