mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-13 01:30:32 +08:00
Some cleanup
This commit is contained in:
parent
224c06bf1f
commit
d9927cdebd
@ -1623,96 +1623,12 @@ class HumoWanModel(WanModel):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
class SCAILWanModel(WanModel):
|
class SCAILWanModel(WanModel):
|
||||||
r"""
|
def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=16, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||||
model_type='scail',
|
|
||||||
patch_size=(1, 2, 2),
|
|
||||||
text_len=512,
|
|
||||||
in_dim=16,
|
|
||||||
dim=5120,
|
|
||||||
ffn_dim=8192,
|
|
||||||
freq_dim=256,
|
|
||||||
text_dim=4096,
|
|
||||||
out_dim=16,
|
|
||||||
num_heads=16,
|
|
||||||
num_layers=32,
|
|
||||||
window_size=(-1, -1),
|
|
||||||
qk_norm=True,
|
|
||||||
cross_attn_norm=True,
|
|
||||||
eps=1e-6,
|
|
||||||
flf_pos_embed_token_number=None,
|
|
||||||
in_dim_ref_conv=None,
|
|
||||||
wan_attn_block_class=WanAttentionBlock,
|
|
||||||
image_model=None,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
operations=None,
|
|
||||||
):
|
|
||||||
|
|
||||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||||
|
|
||||||
self.dtype = dtype
|
|
||||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
|
||||||
|
|
||||||
self.model_type = model_type
|
|
||||||
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.text_len = text_len
|
|
||||||
self.in_dim = in_dim
|
|
||||||
self.dim = dim
|
|
||||||
self.ffn_dim = ffn_dim
|
|
||||||
self.freq_dim = freq_dim
|
|
||||||
self.text_dim = text_dim
|
|
||||||
self.out_dim = out_dim
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.num_layers = num_layers
|
|
||||||
self.window_size = window_size
|
|
||||||
self.qk_norm = qk_norm
|
|
||||||
self.cross_attn_norm = cross_attn_norm
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
# embeddings
|
|
||||||
self.patch_embedding = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
|
|
||||||
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
|
|
||||||
|
|
||||||
self.text_embedding = nn.Sequential(
|
|
||||||
operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
|
|
||||||
operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
|
||||||
|
|
||||||
self.time_embedding = nn.Sequential(
|
|
||||||
operations.Linear(freq_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.SiLU(), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
|
||||||
self.time_projection = nn.Sequential(nn.SiLU(), operations.Linear(dim, dim * 6, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
|
||||||
|
|
||||||
# blocks
|
|
||||||
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
|
||||||
self.blocks = nn.ModuleList([
|
|
||||||
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
|
|
||||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
|
||||||
for i in range(num_layers)
|
|
||||||
])
|
|
||||||
|
|
||||||
# head
|
|
||||||
self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
|
|
||||||
|
|
||||||
d = dim // num_heads
|
|
||||||
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
|
|
||||||
|
|
||||||
self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings)
|
|
||||||
|
|
||||||
def forward_orig(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
context,
|
|
||||||
clip_fea=None,
|
|
||||||
freqs=None,
|
|
||||||
transformer_options={},
|
|
||||||
pose_latents=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
|
|
||||||
reference_latent = kwargs.get("reference_latent", None)
|
reference_latent = kwargs.get("reference_latent", None)
|
||||||
if reference_latent is not None:
|
if reference_latent is not None:
|
||||||
@ -1778,62 +1694,27 @@ class SCAILWanModel(WanModel):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, transformer_options={}):
|
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, transformer_options={}):
|
||||||
patch_size = self.patch_size
|
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
|
||||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
|
||||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
|
||||||
|
|
||||||
if steps_t is None:
|
if pose_latents is None:
|
||||||
steps_t = t_len
|
return main_freqs
|
||||||
if steps_h is None:
|
|
||||||
steps_h = h_len
|
|
||||||
if steps_w is None:
|
|
||||||
steps_w = w_len
|
|
||||||
|
|
||||||
h_start = 0
|
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||||
w_start = 0
|
downscale = H_pose != h
|
||||||
rope_options = transformer_options.get("rope_options", None)
|
# when using pose downscaling, encode at the actual resolution so the freq space covers the right range, then pool back down below
|
||||||
if rope_options is not None:
|
pose_H_virtual = H_pose * 2 if downscale else H_pose
|
||||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
pose_W_virtual = W_pose * 2 if downscale else W_pose
|
||||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
|
||||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
|
||||||
|
|
||||||
t_start += rope_options.get("shift_t", 0.0)
|
pose_transformer_options = {"rope_options": {"shift_x": 120.0}} # pose frames use a fixed w-offset of 120 to spatially separate them from the main frames
|
||||||
h_start += rope_options.get("shift_y", 0.0)
|
pose_freqs = super().rope_encode(F_pose, pose_H_virtual, pose_W_virtual, t_start=t_start, device=device, dtype=dtype, transformer_options=pose_transformer_options)
|
||||||
w_start += rope_options.get("shift_x", 0.0)
|
|
||||||
|
|
||||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
freqs = torch.cat([main_freqs, pose_freqs], dim=1)
|
||||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
|
||||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
|
||||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
|
||||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
|
||||||
|
|
||||||
segments = [img_ids] # Start with main frames
|
# downsample pose frequencies to match actual pose input resolution
|
||||||
|
if downscale:
|
||||||
# Pose frames position IDs
|
|
||||||
if pose_latents is not None:
|
|
||||||
pose_frame_shape = pose_latents.shape
|
|
||||||
F_pose, H_pose, W_pose = pose_frame_shape[-3], pose_frame_shape[-2], pose_frame_shape[-1]
|
|
||||||
|
|
||||||
downscale = H_pose != h
|
|
||||||
pose_f_len_full = ((F_pose + (self.patch_size[0] // 2)) // self.patch_size[0])
|
pose_f_len_full = ((F_pose + (self.patch_size[0] // 2)) // self.patch_size[0])
|
||||||
pose_h_len_full = (((H_pose * (2 if downscale else 1)) + (self.patch_size[1] // 2)) // self.patch_size[1]) # 2x height
|
pose_h_len_full = (((H_pose * 2) + (self.patch_size[1] // 2)) // self.patch_size[1])
|
||||||
pose_w_len_full = (((W_pose * (2 if downscale else 1)) + (self.patch_size[2] // 2)) // self.patch_size[2]) # 2x width
|
pose_w_len_full = (((W_pose * 2) + (self.patch_size[2] // 2)) // self.patch_size[2])
|
||||||
|
|
||||||
pose_img_ids = torch.zeros((pose_f_len_full, pose_h_len_full, pose_w_len_full, 3), device=device, dtype=dtype)
|
|
||||||
global_h_offset, global_w_offset = 0, 120 # global spatial offset to separate pose from main frames spatially (SCAIL uses 120 as offset)
|
|
||||||
pose_img_ids[:, :, :, 0] = pose_img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (pose_f_len_full - 1), steps=pose_f_len_full, device=device, dtype=dtype).reshape(-1, 1, 1)
|
|
||||||
pose_img_ids[:, :, :, 1] = pose_img_ids[:, :, :, 1] + torch.linspace(global_h_offset, global_h_offset + pose_h_len_full - 1, steps=pose_h_len_full, device=device, dtype=dtype).reshape(1, -1, 1)
|
|
||||||
pose_img_ids[:, :, :, 2] = pose_img_ids[:, :, :, 2] + torch.linspace(global_w_offset, global_w_offset + pose_w_len_full - 1, steps=pose_w_len_full, device=device, dtype=dtype).reshape(1, 1, -1)
|
|
||||||
|
|
||||||
segments.append(pose_img_ids.reshape(1, -1, pose_img_ids.shape[-1]))
|
|
||||||
|
|
||||||
combined_img_ids = torch.cat(segments, dim=1)
|
|
||||||
|
|
||||||
freqs = self.rope_embedder(combined_img_ids).movedim(1, 2)
|
|
||||||
|
|
||||||
# Downsample pose frequencies to match actual pose input resolution
|
|
||||||
if pose_latents is not None and downscale:
|
|
||||||
pose_h_len_actual = ((H_pose + (self.patch_size[1] // 2)) // self.patch_size[1])
|
pose_h_len_actual = ((H_pose + (self.patch_size[1] // 2)) // self.patch_size[1])
|
||||||
pose_w_len_actual = ((W_pose + (self.patch_size[2] // 2)) // self.patch_size[2])
|
pose_w_len_actual = ((W_pose + (self.patch_size[2] // 2)) // self.patch_size[2])
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user