mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
fix imports
This commit is contained in:
parent
9d6d09608a
commit
ff66341112
@ -1,5 +1,5 @@
|
|||||||
import nodes
|
import nodes
|
||||||
import node_helpers
|
from comfy import node_helpers
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.functional as TF
|
import torchvision.transforms.functional as TF
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -7,13 +7,14 @@ import comfy.utils
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
from comfy_extras.nodes_wan import parse_json_tracks
|
from comfy_extras.nodes.nodes_wan import parse_json_tracks
|
||||||
|
|
||||||
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
|
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
SKIP_ZERO = False
|
SKIP_ZERO = False
|
||||||
|
|
||||||
|
|
||||||
def get_pos_emb(
|
def get_pos_emb(
|
||||||
pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
|
pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
|
||||||
pos_emb_dim: int,
|
pos_emb_dim: int,
|
||||||
@ -44,6 +45,7 @@ def get_pos_emb(
|
|||||||
|
|
||||||
return pos_emb
|
return pos_emb
|
||||||
|
|
||||||
|
|
||||||
def create_pos_embeddings(
|
def create_pos_embeddings(
|
||||||
pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
|
pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
|
||||||
pred_visibility: torch.Tensor, # the predicted visibility [T, N]
|
pred_visibility: torch.Tensor, # the predicted visibility [T, N]
|
||||||
@ -83,6 +85,7 @@ def create_pos_embeddings(
|
|||||||
|
|
||||||
return track_pos # the position embeddings, [N, T', 2], 2 = height, width
|
return track_pos # the position embeddings, [N, T', 2], 2 = height, width
|
||||||
|
|
||||||
|
|
||||||
def replace_feature(
|
def replace_feature(
|
||||||
vae_feature: torch.Tensor, # [B, C', T', H', W']
|
vae_feature: torch.Tensor, # [B, C', T', H', W']
|
||||||
track_pos: torch.Tensor, # [B, N, T', 2]
|
track_pos: torch.Tensor, # [B, N, T', 2]
|
||||||
@ -126,9 +129,9 @@ def replace_feature(
|
|||||||
|
|
||||||
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
|
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
|
||||||
|
|
||||||
|
|
||||||
return vae_feature
|
return vae_feature
|
||||||
|
|
||||||
|
|
||||||
# Visualize functions
|
# Visualize functions
|
||||||
|
|
||||||
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
|
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
|
||||||
@ -181,6 +184,7 @@ def add_weighted(rgb, track):
|
|||||||
|
|
||||||
return Image.fromarray(blend_img.astype(np.uint8))
|
return Image.fromarray(blend_img.astype(np.uint8))
|
||||||
|
|
||||||
|
|
||||||
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
|
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
|
||||||
color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
|
color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
|
||||||
|
|
||||||
@ -531,5 +535,6 @@ class WanMoveExtension(ComfyExtension):
|
|||||||
GenerateTracks,
|
GenerateTracks,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> WanMoveExtension:
|
async def comfy_entrypoint() -> WanMoveExtension:
|
||||||
return WanMoveExtension()
|
return WanMoveExtension()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user