mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-23 04:50:49 +08:00
lora, 7b model, cfg
This commit is contained in:
parent
768c9cedf8
commit
58e7cea796
@ -1331,15 +1331,14 @@ class NaDiT(nn.Module):
|
||||
**kwargs
|
||||
):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
conditions = kwargs.get("condition")
|
||||
|
||||
pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0)
|
||||
pos_cond, neg_cond = context.chunk(2, dim=0)
|
||||
pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
|
||||
pos_cond, txt_shape = flatten([pos_cond])
|
||||
neg_cond, _ = flatten([neg_cond])
|
||||
txt = torch.cat([pos_cond, neg_cond], dim = 0)
|
||||
txt, txt_shape = flatten([pos_cond, neg_cond])
|
||||
|
||||
vid = x
|
||||
vid, vid_shape = flatten(x)
|
||||
cond_latent, _ = flatten(conditions)
|
||||
|
||||
@ -1360,14 +1359,36 @@ class NaDiT(nn.Module):
|
||||
|
||||
cache = Cache(disable=disable_cache)
|
||||
for i, block in enumerate(self.blocks):
|
||||
vid, txt, vid_shape, txt_shape = block(
|
||||
vid=vid,
|
||||
txt=txt,
|
||||
vid_shape=vid_shape,
|
||||
txt_shape=txt_shape,
|
||||
emb=emb,
|
||||
cache=cache,
|
||||
)
|
||||
if ("block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block(
|
||||
vid=args["vid"],
|
||||
txt=args["txt"],
|
||||
vid_shape=args["vid_shape"],
|
||||
txt_shape=args["txt_shape"],
|
||||
emb=args["emb"],
|
||||
cache=args["cache"],
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("block", i)]({
|
||||
"vid":vid,
|
||||
"txt":txt,
|
||||
"vid_shape":vid_shape,
|
||||
"txt_shape":txt_shape,
|
||||
"emb":emb,
|
||||
"cache":cache,
|
||||
}, {"original_block": block_wrap})
|
||||
vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"]
|
||||
else:
|
||||
vid, txt, vid_shape, txt_shape = block(
|
||||
vid=vid,
|
||||
txt=txt,
|
||||
vid_shape=vid_shape,
|
||||
txt_shape=txt_shape,
|
||||
emb=emb,
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
if self.vid_out_norm:
|
||||
vid = self.vid_out_norm(vid)
|
||||
@ -1383,4 +1404,4 @@ class NaDiT(nn.Module):
|
||||
|
||||
vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
|
||||
vid = unflatten(vid, vid_shape)
|
||||
return vid[0]
|
||||
return torch.stack(vid)
|
||||
|
||||
@ -342,6 +342,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
return dit_config
|
||||
|
||||
elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 3072
|
||||
dit_config["heads"] = 24
|
||||
dit_config["num_layers"] = 36
|
||||
dit_config["norm_eps"] = 1e-5
|
||||
dit_config["qk_rope"] = True
|
||||
dit_config["mlp_type"] = "normal"
|
||||
return dit_config
|
||||
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
@ -352,7 +362,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["qk_rope"] = None
|
||||
dit_config["mlp_type"] = "swiglu"
|
||||
dit_config["vid_out_norm"] = True
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
|
||||
@ -1163,6 +1163,9 @@ class SeedVR2(supported_models_base.BASE):
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix = "", device=None):
|
||||
out = model_base.SeedVR2(self, device=device)
|
||||
|
||||
@ -198,9 +198,8 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
else:
|
||||
pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
|
||||
|
||||
cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)
|
||||
negative = [[cond, {"condition": condition}]]
|
||||
positive = [[cond, {"condition": condition}]]
|
||||
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
|
||||
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": noises})
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user