mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-23 13:00:54 +08:00
lora, 7b model, cfg
This commit is contained in:
parent
768c9cedf8
commit
58e7cea796
@ -1331,15 +1331,14 @@ class NaDiT(nn.Module):
|
|||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
transformer_options = kwargs.get("transformer_options", {})
|
transformer_options = kwargs.get("transformer_options", {})
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
conditions = kwargs.get("condition")
|
conditions = kwargs.get("condition")
|
||||||
|
|
||||||
pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0)
|
pos_cond, neg_cond = context.chunk(2, dim=0)
|
||||||
pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
|
pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
|
||||||
pos_cond, txt_shape = flatten([pos_cond])
|
txt, txt_shape = flatten([pos_cond, neg_cond])
|
||||||
neg_cond, _ = flatten([neg_cond])
|
|
||||||
txt = torch.cat([pos_cond, neg_cond], dim = 0)
|
|
||||||
|
|
||||||
vid = x
|
|
||||||
vid, vid_shape = flatten(x)
|
vid, vid_shape = flatten(x)
|
||||||
cond_latent, _ = flatten(conditions)
|
cond_latent, _ = flatten(conditions)
|
||||||
|
|
||||||
@ -1360,14 +1359,36 @@ class NaDiT(nn.Module):
|
|||||||
|
|
||||||
cache = Cache(disable=disable_cache)
|
cache = Cache(disable=disable_cache)
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
vid, txt, vid_shape, txt_shape = block(
|
if ("block", i) in blocks_replace:
|
||||||
vid=vid,
|
def block_wrap(args):
|
||||||
txt=txt,
|
out = {}
|
||||||
vid_shape=vid_shape,
|
out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block(
|
||||||
txt_shape=txt_shape,
|
vid=args["vid"],
|
||||||
emb=emb,
|
txt=args["txt"],
|
||||||
cache=cache,
|
vid_shape=args["vid_shape"],
|
||||||
)
|
txt_shape=args["txt_shape"],
|
||||||
|
emb=args["emb"],
|
||||||
|
cache=args["cache"],
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
out = blocks_replace[("block", i)]({
|
||||||
|
"vid":vid,
|
||||||
|
"txt":txt,
|
||||||
|
"vid_shape":vid_shape,
|
||||||
|
"txt_shape":txt_shape,
|
||||||
|
"emb":emb,
|
||||||
|
"cache":cache,
|
||||||
|
}, {"original_block": block_wrap})
|
||||||
|
vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"]
|
||||||
|
else:
|
||||||
|
vid, txt, vid_shape, txt_shape = block(
|
||||||
|
vid=vid,
|
||||||
|
txt=txt,
|
||||||
|
vid_shape=vid_shape,
|
||||||
|
txt_shape=txt_shape,
|
||||||
|
emb=emb,
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
|
||||||
if self.vid_out_norm:
|
if self.vid_out_norm:
|
||||||
vid = self.vid_out_norm(vid)
|
vid = self.vid_out_norm(vid)
|
||||||
@ -1383,4 +1404,4 @@ class NaDiT(nn.Module):
|
|||||||
|
|
||||||
vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
|
vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
|
||||||
vid = unflatten(vid, vid_shape)
|
vid = unflatten(vid, vid_shape)
|
||||||
return vid[0]
|
return torch.stack(vid)
|
||||||
|
|||||||
@ -342,6 +342,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["axes_lens"] = [300, 512, 512]
|
dit_config["axes_lens"] = [300, 512, 512]
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "seedvr2"
|
||||||
|
dit_config["vid_dim"] = 3072
|
||||||
|
dit_config["heads"] = 24
|
||||||
|
dit_config["num_layers"] = 36
|
||||||
|
dit_config["norm_eps"] = 1e-5
|
||||||
|
dit_config["qk_rope"] = True
|
||||||
|
dit_config["mlp_type"] = "normal"
|
||||||
|
return dit_config
|
||||||
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "seedvr2"
|
dit_config["image_model"] = "seedvr2"
|
||||||
@ -352,7 +362,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["qk_rope"] = None
|
dit_config["qk_rope"] = None
|
||||||
dit_config["mlp_type"] = "swiglu"
|
dit_config["mlp_type"] = "swiglu"
|
||||||
dit_config["vid_out_norm"] = True
|
dit_config["vid_out_norm"] = True
|
||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||||
|
|||||||
@ -1163,6 +1163,9 @@ class SeedVR2(supported_models_base.BASE):
|
|||||||
vae_key_prefix = ["vae."]
|
vae_key_prefix = ["vae."]
|
||||||
text_encoder_key_prefix = ["text_encoders."]
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix = "", device=None):
|
def get_model(self, state_dict, prefix = "", device=None):
|
||||||
out = model_base.SeedVR2(self, device=device)
|
out = model_base.SeedVR2(self, device=device)
|
||||||
|
|||||||
@ -198,9 +198,8 @@ class SeedVR2Conditioning(io.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
|
pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
|
||||||
|
|
||||||
cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)
|
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
|
||||||
negative = [[cond, {"condition": condition}]]
|
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
|
||||||
positive = [[cond, {"condition": condition}]]
|
|
||||||
|
|
||||||
return io.NodeOutput(positive, negative, {"samples": noises})
|
return io.NodeOutput(positive, negative, {"samples": noises})
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user