fix imports

This commit is contained in:
doctorpangloss 2025-12-12 12:23:28 -08:00
parent 9d6d09608a
commit ff66341112

View File

@ -1,5 +1,5 @@
import nodes import nodes
import node_helpers from comfy import node_helpers
import torch import torch
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
import comfy.model_management import comfy.model_management
@ -7,13 +7,14 @@ import comfy.utils
import numpy as np import numpy as np
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_wan import parse_json_tracks from comfy_extras.nodes.nodes_wan import parse_json_tracks
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py # https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
SKIP_ZERO = False SKIP_ZERO = False
def get_pos_emb( def get_pos_emb(
pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings. pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
pos_emb_dim: int, pos_emb_dim: int,
@ -44,6 +45,7 @@ def get_pos_emb(
return pos_emb return pos_emb
def create_pos_embeddings( def create_pos_embeddings(
pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2] pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
pred_visibility: torch.Tensor, # the predicted visibility [T, N] pred_visibility: torch.Tensor, # the predicted visibility [T, N]
@ -83,6 +85,7 @@ def create_pos_embeddings(
return track_pos # the position embeddings, [N, T', 2], 2 = height, width return track_pos # the position embeddings, [N, T', 2], 2 = height, width
def replace_feature( def replace_feature(
vae_feature: torch.Tensor, # [B, C', T', H', W'] vae_feature: torch.Tensor, # [B, C', T', H', W']
track_pos: torch.Tensor, # [B, N, T', 2] track_pos: torch.Tensor, # [B, N, T', 2]
@ -126,9 +129,9 @@ def replace_feature(
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
return vae_feature return vae_feature
# Visualize functions # Visualize functions
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0): def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
@ -181,6 +184,7 @@ def add_weighted(rgb, track):
return Image.fromarray(blend_img.astype(np.uint8)) return Image.fromarray(blend_img.astype(np.uint8))
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16): def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)] color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
@ -531,5 +535,6 @@ class WanMoveExtension(ComfyExtension):
GenerateTracks, GenerateTracks,
] ]
async def comfy_entrypoint() -> WanMoveExtension: async def comfy_entrypoint() -> WanMoveExtension:
return WanMoveExtension() return WanMoveExtension()