From d9927cdebd227d3d19e5f12bfe28e4432a4fd736 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 24 Feb 2026 15:05:23 +0200 Subject: [PATCH] Some cleanup --- comfy/ldm/wan/model.py | 157 +++++------------------------------------ 1 file changed, 19 insertions(+), 138 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index d1ca9926e..f1ae2e896 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1623,96 +1623,12 @@ class HumoWanModel(WanModel): return x class SCAILWanModel(WanModel): - r""" - Wan diffusion backbone supporting both text-to-video and image-to-video. - """ + def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=16, dim=5120, operations=None, device=None, dtype=None, **kwargs): + super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) - def __init__(self, - model_type='scail', - patch_size=(1, 2, 2), - text_len=512, - in_dim=16, - dim=5120, - ffn_dim=8192, - freq_dim=256, - text_dim=4096, - out_dim=16, - num_heads=16, - num_layers=32, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=True, - eps=1e-6, - flf_pos_embed_token_number=None, - in_dim_ref_conv=None, - wan_attn_block_class=WanAttentionBlock, - image_model=None, - device=None, - dtype=None, - operations=None, - ): + self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) - super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations) - - self.dtype = dtype - operation_settings = {"operations": operations, "device": device, "dtype": dtype} - - self.model_type = model_type - - self.patch_size = patch_size - self.text_len = text_len - self.in_dim = in_dim - self.dim = dim - self.ffn_dim = ffn_dim - self.freq_dim = freq_dim - self.text_dim = text_dim - self.out_dim = out_dim - self.num_heads = num_heads - self.num_layers = num_layers - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps - - # embeddings - self.patch_embedding = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32) - self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32) - - self.text_embedding = nn.Sequential( - operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'), - operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) - - self.time_embedding = nn.Sequential( - operations.Linear(freq_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.SiLU(), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) - self.time_projection = nn.Sequential(nn.SiLU(), operations.Linear(dim, dim * 6, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) - - # blocks - cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' - self.blocks = nn.ModuleList([ - wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, - window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) - for i in range(num_layers) - ]) - - # head - self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings) - - d = dim // num_heads - self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) - - self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings) - - def forward_orig( - self, - x, - t, - context, - clip_fea=None, - freqs=None, - transformer_options={}, - pose_latents=None, - **kwargs, - ): + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, **kwargs): reference_latent = kwargs.get("reference_latent", None) if reference_latent is not None: @@ -1778,62 +1694,27 @@ class SCAILWanModel(WanModel): return x def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, transformer_options={}): - patch_size = self.patch_size - t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) - h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) - w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) - if steps_t is None: - steps_t = t_len - if steps_h is None: - steps_h = h_len - if steps_w is None: - steps_w = w_len + if pose_latents is None: + return main_freqs - h_start = 0 - w_start = 0 - rope_options = transformer_options.get("rope_options", None) - if rope_options is not None: - t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 - h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 - w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] + downscale = H_pose != h + # when using pose downscaling, encode at the actual resolution so the freq space covers the right range, then pool back down below + pose_H_virtual = H_pose * 2 if downscale else H_pose + pose_W_virtual = W_pose * 2 if downscale else W_pose - t_start += rope_options.get("shift_t", 0.0) - h_start += rope_options.get("shift_y", 0.0) - w_start += rope_options.get("shift_x", 0.0) + pose_transformer_options = {"rope_options": {"shift_x": 120.0}} # pose frames use a fixed w-offset of 120 to spatially separate them from the main frames + pose_freqs = super().rope_encode(F_pose, pose_H_virtual, pose_W_virtual, t_start=t_start, device=device, dtype=dtype, transformer_options=pose_transformer_options) - img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) - img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) - img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) - img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) - img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + freqs = torch.cat([main_freqs, pose_freqs], dim=1) - segments = [img_ids] # Start with main frames - - # Pose frames position IDs - if pose_latents is not None: - pose_frame_shape = pose_latents.shape - F_pose, H_pose, W_pose = pose_frame_shape[-3], pose_frame_shape[-2], pose_frame_shape[-1] - - downscale = H_pose != h + # downsample pose frequencies to match actual pose input resolution + if downscale: pose_f_len_full = ((F_pose + (self.patch_size[0] // 2)) // self.patch_size[0]) - pose_h_len_full = (((H_pose * (2 if downscale else 1)) + (self.patch_size[1] // 2)) // self.patch_size[1]) # 2x height - pose_w_len_full = (((W_pose * (2 if downscale else 1)) + (self.patch_size[2] // 2)) // self.patch_size[2]) # 2x width - - pose_img_ids = torch.zeros((pose_f_len_full, pose_h_len_full, pose_w_len_full, 3), device=device, dtype=dtype) - global_h_offset, global_w_offset = 0, 120 # global spatial offset to separate pose from main frames spatially (SCAIL uses 120 as offset) - pose_img_ids[:, :, :, 0] = pose_img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (pose_f_len_full - 1), steps=pose_f_len_full, device=device, dtype=dtype).reshape(-1, 1, 1) - pose_img_ids[:, :, :, 1] = pose_img_ids[:, :, :, 1] + torch.linspace(global_h_offset, global_h_offset + pose_h_len_full - 1, steps=pose_h_len_full, device=device, dtype=dtype).reshape(1, -1, 1) - pose_img_ids[:, :, :, 2] = pose_img_ids[:, :, :, 2] + torch.linspace(global_w_offset, global_w_offset + pose_w_len_full - 1, steps=pose_w_len_full, device=device, dtype=dtype).reshape(1, 1, -1) - - segments.append(pose_img_ids.reshape(1, -1, pose_img_ids.shape[-1])) - - combined_img_ids = torch.cat(segments, dim=1) - - freqs = self.rope_embedder(combined_img_ids).movedim(1, 2) - - # Downsample pose frequencies to match actual pose input resolution - if pose_latents is not None and downscale: + pose_h_len_full = (((H_pose * 2) + (self.patch_size[1] // 2)) // self.patch_size[1]) + pose_w_len_full = (((W_pose * 2) + (self.patch_size[2] // 2)) // self.patch_size[2]) pose_h_len_actual = ((H_pose + (self.patch_size[1] // 2)) // self.patch_size[1]) pose_w_len_actual = ((W_pose + (self.patch_size[2] // 2)) // self.patch_size[2])