From 58e7cea79635f78a05cfc9e2d936655e1a026520 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Sat, 13 Dec 2025 19:48:57 +0200
Subject: [PATCH] lora, 7b model, cfg

---
 comfy/ldm/seedvr/model.py    | 49 +++++++++++++++++++++++++-----------
 comfy/model_detection.py     | 11 +++++++-
 comfy/supported_models.py    |  3 +++
 comfy_extras/nodes_seedvr.py |  5 ++--
 4 files changed, 50 insertions(+), 18 deletions(-)

diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index cbf1383d3..9b69c85a1 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -1331,15 +1331,14 @@ class NaDiT(nn.Module):
         **kwargs
     ):
         transformer_options = kwargs.get("transformer_options", {})
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
         conditions = kwargs.get("condition")

-        pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0)
+        pos_cond, neg_cond = context.chunk(2, dim=0)
         pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)

-        pos_cond, txt_shape = flatten([pos_cond])
-        neg_cond, _ = flatten([neg_cond])
-        txt = torch.cat([pos_cond, neg_cond], dim = 0)
+        txt, txt_shape = flatten([pos_cond, neg_cond])

-        vid = x
         vid, vid_shape = flatten(x)
         cond_latent, _ = flatten(conditions)
@@ -1360,14 +1359,36 @@ class NaDiT(nn.Module):
         cache = Cache(disable=disable_cache)

         for i, block in enumerate(self.blocks):
-            vid, txt, vid_shape, txt_shape = block(
-                vid=vid,
-                txt=txt,
-                vid_shape=vid_shape,
-                txt_shape=txt_shape,
-                emb=emb,
-                cache=cache,
-            )
+            if ("block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block(
+                        vid=args["vid"],
+                        txt=args["txt"],
+                        vid_shape=args["vid_shape"],
+                        txt_shape=args["txt_shape"],
+                        emb=args["emb"],
+                        cache=args["cache"],
+                    )
+                    return out
+                out = blocks_replace[("block", i)]({
+                    "vid":vid,
+                    "txt":txt,
+                    "vid_shape":vid_shape,
+                    "txt_shape":txt_shape,
+                    "emb":emb,
+                    "cache":cache,
+                }, {"original_block": block_wrap})
+                vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"]
+            else:
+                vid, txt, vid_shape, txt_shape = block(
+                    vid=vid,
+                    txt=txt,
+                    vid_shape=vid_shape,
+                    txt_shape=txt_shape,
+                    emb=emb,
+                    cache=cache,
+                )

         if self.vid_out_norm:
             vid = self.vid_out_norm(vid)
@@ -1383,4 +1404,4 @@ class NaDiT(nn.Module):
         vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)

         vid = unflatten(vid, vid_shape)
-        return vid[0]
+        return torch.stack(vid)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 804878432..22e774730 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -342,6 +342,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["axes_lens"] = [300, 512, 512]
         return dit_config
+    elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
+        dit_config = {}
+        dit_config["image_model"] = "seedvr2"
+        dit_config["vid_dim"] = 3072
+        dit_config["heads"] = 24
+        dit_config["num_layers"] = 36
+        dit_config["norm_eps"] = 1e-5
+        dit_config["qk_rope"] = True
+        dit_config["mlp_type"] = "normal"
+        return dit_config
     elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
         dit_config = {}
         dit_config["image_model"] = "seedvr2"
         dit_config["vid_dim"] = 2560
         dit_config["heads"] = 20
         dit_config["num_layers"] = 32
         dit_config["norm_eps"] = 1e-5
         dit_config["qk_rope"] = None
         dit_config["mlp_type"] = "swiglu"
dit_config["vid_out_norm"] = True - return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1cab38f97..a5f116327 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1163,6 +1163,9 @@ class SeedVR2(supported_models_base.BASE): vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] supported_inference_dtypes = [torch.bfloat16, torch.float32] + sampling_settings = { + "shift": 1.0, + } def get_model(self, state_dict, prefix = "", device=None): out = model_base.SeedVR2(self, device=device) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index eebcb7dc0..08009b4d9 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -198,9 +198,8 @@ class SeedVR2Conditioning(io.ComfyNode): else: pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) - cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0) - negative = [[cond, {"condition": condition}]] - positive = [[cond, {"condition": condition}]] + negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] + positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] return io.NodeOutput(positive, negative, {"samples": noises})