diff --git a/comfy_extras/nodes/nodes_wanmove.py b/comfy_extras/nodes/nodes_wanmove.py
index 5f39afa46..8d1bd8323 100644
--- a/comfy_extras/nodes/nodes_wanmove.py
+++ b/comfy_extras/nodes/nodes_wanmove.py
@@ -1,5 +1,5 @@
 import nodes
-import node_helpers
+from comfy import node_helpers
 import torch
 import torchvision.transforms.functional as TF
 import comfy.model_management
@@ -7,20 +7,21 @@
 import comfy.utils
 import numpy as np
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
-from comfy_extras.nodes_wan import parse_json_tracks
+from comfy_extras.nodes.nodes_wan import parse_json_tracks
 # https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
 from PIL import Image, ImageDraw
 
 SKIP_ZERO = False
 
+
 def get_pos_emb(
-    pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
-    pos_emb_dim: int,
-    theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), #Function to compute thetas based on position and embedding dimensions.
-    device: torch.device = torch.device("cpu"),
-    dtype: torch.dtype = torch.float32,
-) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim)
+    pos_k: torch.Tensor,  # A 1D tensor containing positions for which to generate embeddings.
+    pos_emb_dim: int,
+    theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))),  # Function to compute thetas based on position and embedding dimensions.
+    device: torch.device = torch.device("cpu"),
+    dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:  # The position embeddings (batch_size, pos_emb_dim)
     assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."
 
     pos_k = pos_k.to(device, dtype)
@@ -44,20 +45,21 @@ def get_pos_emb(
 
     return pos_emb
 
+
 def create_pos_embeddings(
-    pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
-    pred_visibility: torch.Tensor, # the predicted visibility [T, N]
-    downsample_ratios: list[int], # the ratios for downsampling time, height, and width
-    height: int, # the height of the feature map
-    width: int, # the width of the feature map
-    track_num: int = -1, # the number of tracks to use
-    t_down_strategy: str = "sample", # the strategy for downsampling time dimension
+    pred_tracks: torch.Tensor,  # the predicted tracks, [T, N, 2]
+    pred_visibility: torch.Tensor,  # the predicted visibility [T, N]
+    downsample_ratios: list[int],  # the ratios for downsampling time, height, and width
+    height: int,  # the height of the feature map
+    width: int,  # the width of the feature map
+    track_num: int = -1,  # the number of tracks to use
+    t_down_strategy: str = "sample",  # the strategy for downsampling time dimension
 ):
     assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."
 
     t, n, _ = pred_tracks.shape
     t_down, h_down, w_down = downsample_ratios
-    track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long)
+    track_pos = -torch.ones(n, (t - 1) // t_down + 1, 2, dtype=torch.long)
 
     if track_num == -1:
         track_num = n
@@ -68,11 +70,11 @@ def create_pos_embeddings(
 
     for t_idx in range(0, t, t_down):
         if t_down_strategy == "sample" or t_idx == 0:
-            cur_tracks = tracks[t_idx] # [N, 2]
-            cur_visibility = visibility[t_idx] # [N]
+            cur_tracks = tracks[t_idx]  # [N, 2]
+            cur_visibility = visibility[t_idx]  # [N]
         else:
-            cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0)
-            cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0)
+            cur_tracks = tracks[t_idx:t_idx + t_down].mean(dim=0)
+            cur_visibility = torch.any(visibility[t_idx:t_idx + t_down], dim=0)
 
         for i in range(track_num):
             if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height:
@@ -81,12 +83,13 @@ def create_pos_embeddings(
             x, y = int(x // w_down), int(y // h_down)
             track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x
 
-    return track_pos # the position embeddings, [N, T', 2], 2 = height, width
+    return track_pos  # the position embeddings, [N, T', 2], 2 = height, width
+
 
 def replace_feature(
-    vae_feature: torch.Tensor, # [B, C', T', H', W']
-    track_pos: torch.Tensor, # [B, N, T', 2]
-    strength: float = 1.0
+    vae_feature: torch.Tensor,  # [B, C', T', H', W']
+    track_pos: torch.Tensor,  # [B, N, T', 2]
+    strength: float = 1.0
 ) -> torch.Tensor:
     b, _, t, h, w = vae_feature.shape
     assert b == track_pos.shape[0], "Batch size mismatch."
@@ -126,9 +129,9 @@ def replace_feature(
             vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
-
     return vae_feature
+
 
 # Visualize functions
 def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
@@ -172,8 +175,8 @@ def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color,
 
 
 def add_weighted(rgb, track):
-    rgb = np.array(rgb) # [H, W, C] "RGB"
-    track = np.array(track) # [H, W, C] "RGBA"
+    rgb = np.array(rgb)  # [H, W, C] "RGB"
+    track = np.array(track)  # [H, W, C] "RGBA"
 
     alpha = track[:, :, 3] / 255.0
     alpha = np.stack([alpha] * 3, axis=-1)
@@ -181,6 +184,7 @@ def add_weighted(rgb, track):
 
     return Image.fromarray(blend_img.astype(np.uint8))
 
+
 def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
     color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
 
@@ -213,8 +217,8 @@ def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_
                 circle_color = color + (alpha_opacity,)
                 draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size,
                                       track_coord[0] + circle_size, track_coord[1] + circle_size),
-                    fill=circle_color
-                )
+                                     fill=circle_color
+                                     )
 
             # Store polyline data for batch processing
             tracks_coord = tracks[max(t - track_frame, 0):t + 1, n]
@@ -296,15 +300,15 @@ class WanMoveTracksFromCoords(io.ComfyNode):
 
     @classmethod
     def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
-        device=comfy.model_management.intermediate_device()
+        device = comfy.model_management.intermediate_device()
         tracks_data = parse_json_tracks(track_coords)
 
         track_length = len(tracks_data[0])
         track_list = [
-            [[track[frame]['x'], track[frame]['y']] for track in tracks_data]
-            for frame in range(len(tracks_data[0]))
-            ]
+            [[track[frame]['x'], track[frame]['y']]
+             for track in tracks_data]
+            for frame in range(len(tracks_data[0]))
+        ]
 
         tracks = torch.tensor(track_list, dtype=torch.float32, device=device)  # [frames, num_tracks, 2]
         num_tracks = tracks.shape[-2]
@@ -365,10 +369,10 @@ class GenerateTracks(io.ComfyNode):
         end_x_px = end_x * width
         end_y_px = end_y * height
-        track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional
+        track_spread_px = track_spread * (width + height) / 2  # Use average of width/height for spread to keep it proportional
 
         t = torch.linspace(0, 1, num_frames, device=device)
-        if interpolation == "constant": # All points stay at start position
+        if interpolation == "constant":  # All points stay at start position
             interp_values = torch.zeros_like(t)
         elif interpolation == "linear":
             interp_values = t
@@ -379,14 +383,14 @@ class GenerateTracks(io.ComfyNode):
         elif interpolation == "ease_in_out":
             interp_values = t * t * (3 - 2 * t)
 
-        if bezier: # apply interpolation to t for timing control along the bezier path
+        if bezier:  # apply interpolation to t for timing control along the bezier path
             t_interp = interp_values
             one_minus_t = 1 - t_interp
             x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
             y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
             tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
             tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
-        else: # calculate base x and y positions for each frame (center track)
+        else:  # calculate base x and y positions for each frame (center track)
             x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
             y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
             # For non-bezier, tangent is constant (direction from start to end)
@@ -400,15 +404,15 @@ class GenerateTracks(io.ComfyNode):
             ty = tangent_y[frame_idx].item()
             length = (tx ** 2 + ty ** 2) ** 0.5
 
-            if length > 0: # Perpendicular unit vector (rotate 90 degrees)
+            if length > 0:  # Perpendicular unit vector (rotate 90 degrees)
                 perp_x = -ty / length
                 perp_y = tx / length
-            else: # If tangent is zero, spread horizontally
+            else:  # If tangent is zero, spread horizontally
                 perp_x = 1.0
                 perp_y = 0.0
 
             frame_tracks = []
-            for track_idx in range(num_tracks): # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
+            for track_idx in range(num_tracks):  # center tracks around the main path; offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
                 offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
                 track_x = x_positions[frame_idx].item() + perp_x * offset
                 track_y = y_positions[frame_idx].item() + perp_y * offset
@@ -485,7 +489,7 @@ class WanMoveTrackToVideo(io.ComfyNode):
 
     @classmethod
     def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
-        device=comfy.model_management.intermediate_device()
+        device = comfy.model_management.intermediate_device()
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
         if start_image is not None:
             start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@@ -531,5 +535,6 @@ class WanMoveExtension(ComfyExtension):
             GenerateTracks,
         ]
 
+
 async def comfy_entrypoint() -> WanMoveExtension:
     return WanMoveExtension()
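Reviewer note, not part of the patch: a standalone sketch of the quadratic Bezier math that GenerateTracks evaluates per frame. The control points below are arbitrary example values, not taken from the node; the check confirms the closed-form tangent (whose perpendicular spreads multiple tracks around the center path) against a finite-difference estimate.

import torch

# Arbitrary example control points in pixels (the node derives these from
# start/mid/end coordinates scaled by the output width/height).
start_x, mid_x, end_x = 0.0, 128.0, 512.0

t = torch.linspace(0, 1, 81)
one_minus_t = 1 - t

# Quadratic Bezier position, as in the bezier branch:
#   x(t) = (1-t)^2 * x0 + 2*(1-t)*t * xm + t^2 * x1
x = one_minus_t ** 2 * start_x + 2 * one_minus_t * t * mid_x + t ** 2 * end_x

# Closed-form tangent: x'(t) = 2*(1-t)*(xm - x0) + 2*t*(x1 - xm)
tangent = 2 * one_minus_t * (mid_x - start_x) + 2 * t * (end_x - mid_x)

# Central differences are exact for quadratics, so interior points should
# agree with the closed form up to floating-point rounding.
fd = (x[2:] - x[:-2]) / (t[2:] - t[:-2])
assert torch.allclose(fd, tangent[1:-1], atol=1e-3)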