lora, 7b model, cfg

Yousef Rafat 2025-12-13 19:48:57 +02:00
parent 768c9cedf8
commit 58e7cea796
4 changed files with 50 additions and 18 deletions

View File

@@ -1331,15 +1331,14 @@ class NaDiT(nn.Module):
        **kwargs
    ):
        transformer_options = kwargs.get("transformer_options", {})
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        conditions = kwargs.get("condition")
        pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0)
        pos_cond, neg_cond = context.chunk(2, dim=0)
        pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
        pos_cond, txt_shape = flatten([pos_cond])
        neg_cond, _ = flatten([neg_cond])
        txt = torch.cat([pos_cond, neg_cond], dim = 0)
        txt, txt_shape = flatten([pos_cond, neg_cond])
        vid = x
        vid, vid_shape = flatten(x)
        cond_latent, _ = flatten(conditions)
@@ -1360,6 +1359,28 @@ class NaDiT(nn.Module):
        cache = Cache(disable=disable_cache)
        for i, block in enumerate(self.blocks):
            if ("block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block(
                        vid=args["vid"],
                        txt=args["txt"],
                        vid_shape=args["vid_shape"],
                        txt_shape=args["txt_shape"],
                        emb=args["emb"],
                        cache=args["cache"],
                    )
                    return out
                out = blocks_replace[("block", i)]({
                    "vid": vid,
                    "txt": txt,
                    "vid_shape": vid_shape,
                    "txt_shape": txt_shape,
                    "emb": emb,
                    "cache": cache,
                }, {"original_block": block_wrap})
                vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"]
            else:
                vid, txt, vid_shape, txt_shape = block(
                    vid=vid,
                    txt=txt,
@@ -1383,4 +1404,4 @@ class NaDiT(nn.Module):
        vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
        vid = unflatten(vid, vid_shape)
        return vid[0]
        return torch.stack(vid)
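The blocks_replace path added above follows the patch hook ComfyUI exposes through transformer_options["patches_replace"]. As a minimal sketch (not part of the commit, names are hypothetical), a caller-side patch registered under the "dit" key would receive the block inputs plus a handle to the original block, mirroring the block_wrap call in the hunk:

# Illustrative sketch only: a patch intercepting block i while still being able
# to call the unmodified block through "original_block".
def my_block_patch(args, extra):
    out = extra["original_block"](args)   # run the original block with the given inputs
    out["vid"] = out["vid"] * 1.0         # hypothetical post-processing of the video stream
    return out

# Hypothetical registration; the key matches the ("block", i) lookup in the loop above.
transformer_options = {"patches_replace": {"dit": {("block", 5): my_block_patch}}}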

View File

@@ -342,6 +342,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["axes_lens"] = [300, 512, 512]
        return dit_config
    elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
        dit_config = {}
        dit_config["image_model"] = "seedvr2"
        dit_config["vid_dim"] = 3072
        dit_config["heads"] = 24
        dit_config["num_layers"] = 36
        dit_config["norm_eps"] = 1e-5
        dit_config["qk_rope"] = True
        dit_config["mlp_type"] = "normal"
        return dit_config
    elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
        dit_config = {}
        dit_config["image_model"] = "seedvr2"
@@ -352,7 +362,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["qk_rope"] = None
        dit_config["mlp_type"] = "swiglu"
        dit_config["vid_out_norm"] = True
        return dit_config
    if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
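The new branch distinguishes the two SeedVR2 checkpoints purely by which block index exists in the state dict: the 7B stack reaches blocks.36.*, while the 3B falls through to the existing blocks.31 check. A sketch of that key-presence test in isolation (function name invented for illustration):

# Sketch only: variant detection by key presence, mirroring the elif chain above.
def seedvr2_variant(state_dict_keys, key_prefix=""):
    if "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys:
        return "7b"
    if "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys:
        return "3b"
    return None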

View File

@@ -1163,6 +1163,9 @@ class SeedVR2(supported_models_base.BASE):
    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]
    supported_inference_dtypes = [torch.bfloat16, torch.float32]
    sampling_settings = {
        "shift": 1.0,
    }
    def get_model(self, state_dict, prefix = "", device=None):
        out = model_base.SeedVR2(self, device=device)
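The added sampling_settings pins the sigma shift to its neutral value. Assuming the discrete-flow time shift used for other flow models applies here (an assumption about the downstream plumbing, not something this diff shows), shift = 1.0 leaves the schedule unchanged:

# Sketch, assuming the usual flow-model time shift shifted(t) = shift*t / (1 + (shift - 1)*t);
# with shift = 1.0 this is the identity, so the default does not warp the schedule.
def shift_time(t, shift=1.0):
    return shift * t / (1 + (shift - 1) * t)

assert shift_time(0.25) == 0.25  # shift = 1.0 changes nothing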

View File

@@ -198,9 +198,8 @@ class SeedVR2Conditioning(io.ComfyNode):
        else:
            pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
        cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0)
        negative = [[cond, {"condition": condition}]]
        positive = [[cond, {"condition": condition}]]
        negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
        positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
        return io.NodeOutput(positive, negative, {"samples": noises})
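Returning positive and negative as separate conditioning lists, instead of the same packed tensor on both outputs, is what lets the sampler apply ordinary classifier-free guidance downstream (presumably the "cfg" in the commit title). As a rough sketch of that combination step, not code from this repository:

# Sketch only: standard CFG combine enabled by having distinct cond/uncond predictions.
import torch

def cfg_combine(pos_pred: torch.Tensor, neg_pred: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # guided = uncond + scale * (cond - uncond)
    return neg_pred + cfg_scale * (pos_pred - neg_pred)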