convert WanCameraEmbedding node to V3 schema (#9714)

This commit is contained in:
Alexander Piskun 2025-09-13 00:38:12 +03:00 committed by GitHub
parent 45bc1f5c00
commit f9d2e4b742
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,12 +2,12 @@ import nodes
import torch import torch
import numpy as np import numpy as np
from einops import rearrange from einops import rearrange
from typing_extensions import override
import comfy.model_management import comfy.model_management
from comfy_api.latest import ComfyExtension, io
MAX_RESOLUTION = nodes.MAX_RESOLUTION
CAMERA_DICT = { CAMERA_DICT = {
"base_T_norm": 1.5, "base_T_norm": 1.5,
"base_angle": np.pi/3, "base_angle": np.pi/3,
@ -148,32 +148,47 @@ def get_camera_motion(angle, T, speed, n=81):
RT = np.stack(RT) RT = np.stack(RT)
return RT return RT
class WanCameraEmbedding: class WanCameraEmbedding(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return io.Schema(
"required": { node_id="WanCameraEmbedding",
"camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}), category="camera",
"width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}), inputs=[
"height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}), io.Combo.Input(
"length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}), "camera_pose",
}, options=[
"optional":{ "Static",
"speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}), "Pan Up",
"fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}), "Pan Down",
"fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}), "Pan Left",
"cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}), "Pan Right",
"cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}), "Zoom In",
} "Zoom Out",
"Anti Clockwise (ACW)",
"ClockWise (CW)",
],
default="Static",
),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True),
io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True),
io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True),
io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True),
],
outputs=[
io.WanCameraEmbedding.Output(display_name="camera_embedding"),
io.Int.Output(display_name="width"),
io.Int.Output(display_name="height"),
io.Int.Output(display_name="length"),
],
)
} @classmethod
def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput:
RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
RETURN_NAMES = ("camera_embedding","width","height","length")
FUNCTION = "run"
CATEGORY = "camera"
def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
""" """
Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021) Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021)
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
@ -210,9 +225,15 @@ class WanCameraEmbedding:
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
return (control_camera_video, width, height, length) return io.NodeOutput(control_camera_video, width, height, length)
NODE_CLASS_MAPPINGS = { class CameraTrajectoryExtension(ComfyExtension):
"WanCameraEmbedding": WanCameraEmbedding, @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanCameraEmbedding,
]
async def comfy_entrypoint() -> CameraTrajectoryExtension:
return CameraTrajectoryExtension()