Merge branch 'comfyanonymous:master' into models-dir-arg

This commit is contained in:
Silver 2025-12-13 07:19:39 +01:00 committed by GitHub
commit 6ee2edde95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 771 additions and 46 deletions

View File

@ -53,6 +53,16 @@ try:
repo.stash(ident)
except KeyError:
print("nothing to stash") # noqa: T201
except:
print("Could not stash, cleaning index and trying again.") # noqa: T201
repo.state_cleanup()
repo.index.read_tree(repo.head.peel().tree)
repo.index.write()
try:
repo.stash(ident)
except KeyError:
print("nothing to stash.") # noqa: T201
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try:

View File

@ -58,8 +58,13 @@ class InternalRoutes:
return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type)
def is_visible_file(entry: os.DirEntry) -> bool:
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
return entry.is_file() and not entry.name.startswith('.')
sorted_files = sorted(
(entry for entry in os.scandir(directory) if entry.is_file()),
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)

View File

@ -259,8 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "__x0__" in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
if "__x0__" in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

View File

@ -497,15 +497,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
if dtype is None:
dtype = MixedPrecisionOps._compute_dtype
self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self._has_bias = bias
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@ -530,7 +529,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
dtype = self.factory_kwargs["dtype"]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
else:
self.register_parameter("bias", None)
else:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm:
@ -560,6 +566,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
requires_grad=False
)
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
else:
self.register_parameter("bias", None)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)

View File

@ -549,8 +549,10 @@ class VAE:
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
# Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:

View File

@ -541,7 +541,7 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.SD3
memory_usage_factor = 1.2
memory_usage_factor = 1.6
text_encoder_key_prefix = ["text_encoders."]
@ -965,7 +965,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
@ -1026,7 +1026,7 @@ class ZImage(Lumina2):
"shift": 3.0,
}
memory_usage_factor = 1.7
memory_usage_factor = 2.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
@ -1289,7 +1289,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.038
memory_usage_factor = 0.044
def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device)
@ -1332,7 +1332,7 @@ class Omnigen2(supported_models_base.BASE):
"shift": 2.6,
}
memory_usage_factor = 1.65 #TODO
memory_usage_factor = 1.95 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
@ -1397,7 +1397,7 @@ class HunyuanImage21(HunyuanVideo):
latent_format = latent_formats.HunyuanImage21
memory_usage_factor = 7.7
memory_usage_factor = 8.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1488,7 +1488,7 @@ class Kandinsky5(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 1.1 #TODO
memory_usage_factor = 1.25 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1517,7 +1517,7 @@ class Kandinsky5Image(Kandinsky5):
}
latent_format = latent_formats.Flux
memory_usage_factor = 1.1 #TODO
memory_usage_factor = 1.25 #TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5Image(self, device=device)

View File

@ -774,6 +774,13 @@ class AudioEncoder(ComfyTypeIO):
class AudioEncoderOutput(ComfyTypeIO):
Type = Any
@comfytype(io_type="TRACKS")
class Tracks(ComfyTypeIO):
class TrackDict(TypedDict):
track_path: torch.Tensor
track_visibility: torch.Tensor
Type = TrackDict
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@ -1815,7 +1822,7 @@ class NodeOutput(_NodeOutputInternal):
ui = data["ui"]
if "expand" in data:
expand = data["expand"]
return cls(args=args, ui=ui, expand=expand)
return cls(*args, ui=ui, expand=expand)
def __getitem__(self, index) -> Any:
return self.args[index]
@ -1894,6 +1901,7 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
"Tracks",
# Dynamic Types
"MatchType",
# "DynamicCombo",

View File

@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel):
url: str = Field(..., description="URL for generated image")
class OmniTaskStatusResults(BaseModel):
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
class OmniTaskStatusResponseData(BaseModel):
class TaskStatusResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID")
task_result: OmniTaskStatusResults | None = Field(None)
task_result: TaskStatusResults | None = Field(None)
class OmniTaskStatusResponse(BaseModel):
class TaskStatusResponse(BaseModel):
code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID")
data: OmniTaskStatusResponseData | None = Field(None)
data: TaskStatusResponseData | None = Field(None)
class OmniImageParamImage(BaseModel):
@ -84,3 +84,21 @@ class OmniProImageRequest(BaseModel):
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
image: str = Field(...)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")

View File

@ -50,6 +50,7 @@ from comfy_api_nodes.apis import (
KlingSingleImageEffectModelName,
)
from comfy_api_nodes.apis.kling_api import (
ImageToVideoWithAudioRequest,
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
@ -57,7 +58,8 @@ from comfy_api_nodes.apis.kling_api import (
OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
OmniTaskStatusResponse,
TaskStatusResponse,
TextToVideoWithAudioRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -242,7 +244,7 @@ def normalize_omni_prompt_references(prompt: str) -> str:
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@ -250,7 +252,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStat
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160,
)
@ -483,12 +485,12 @@ async def execute_image2video(
task_id = task_creation_response.data.task_id
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@ -834,7 +836,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProText2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -929,7 +931,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProFirstLastFrameRequest(
model_name=model_name,
prompt=prompt,
@ -997,7 +999,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1081,7 +1083,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1162,7 +1164,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1237,7 +1239,7 @@ class OmniProImageNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProImageRequest(
model_name=model_name,
prompt=prompt,
@ -1253,7 +1255,7 @@ class OmniProImageNode(IO.ComfyNode):
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
@ -1328,9 +1330,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImage2VideoNode",
display_name="Kling Image to Video",
display_name="Kling Image(First Frame) to Video",
category="api node/video/Kling",
description="Kling Image to Video Node",
inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -2034,6 +2035,136 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.NodeOutput(await image_result_to_node_output(images))
class TextToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
mode: str,
aspect_ratio: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
response_model=TaskStatusResponse,
data=TextToVideoWithAudioRequest(
model_name=model_name,
prompt=prompt,
mode=mode,
aspect_ratio=aspect_ratio,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class ImageToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
start_frame: Input.Image,
prompt: str,
mode: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
response_model=TaskStatusResponse,
data=ImageToVideoWithAudioRequest(
model_name=model_name,
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
prompt=prompt,
mode=mode,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class KlingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2056,7 +2187,9 @@ class KlingExtension(ComfyExtension):
OmniProImageToVideoNode,
OmniProVideoToVideoNode,
OmniProEditVideoNode,
# OmniProImageNode, # need support from backend
OmniProImageNode,
TextToVideoWithAudio,
ImageToVideoWithAudio,
]

View File

@ -0,0 +1,535 @@
import nodes
import node_helpers
import torch
import torchvision.transforms.functional as TF
import comfy.model_management
import comfy.utils
import numpy as np
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_wan import parse_json_tracks
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
from PIL import Image, ImageDraw
SKIP_ZERO = False
def get_pos_emb(
pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
pos_emb_dim: int,
theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), #Function to compute thetas based on position and embedding dimensions.
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim)
assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."
pos_k = pos_k.to(device, dtype)
if SKIP_ZERO:
pos_k = pos_k + 1
batch_size = pos_k.size(0)
denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype)
# Expand denominator to match the shape needed for broadcasting
denominator_expanded = denominator.view(1, -1).expand(batch_size, -1)
thetas = theta_func(denominator_expanded, pos_emb_dim)
# Ensure pos_k is in the correct shape for broadcasting
pos_k_expanded = pos_k.view(-1, 1).to(dtype)
sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas))
cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas))
# Concatenate sine and cosine embeddings along the last dimension
pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1)
return pos_emb
def create_pos_embeddings(
pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
pred_visibility: torch.Tensor, # the predicted visibility [T, N]
downsample_ratios: list[int], # the ratios for downsampling time, height, and width
height: int, # the height of the feature map
width: int, # the width of the feature map
track_num: int = -1, # the number of tracks to use
t_down_strategy: str = "sample", # the strategy for downsampling time dimension
):
assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."
t, n, _ = pred_tracks.shape
t_down, h_down, w_down = downsample_ratios
track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long)
if track_num == -1:
track_num = n
tracks_idx = torch.randperm(n)[:track_num]
tracks = pred_tracks[:, tracks_idx]
visibility = pred_visibility[:, tracks_idx]
for t_idx in range(0, t, t_down):
if t_down_strategy == "sample" or t_idx == 0:
cur_tracks = tracks[t_idx] # [N, 2]
cur_visibility = visibility[t_idx] # [N]
else:
cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0)
cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0)
for i in range(track_num):
if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height:
continue
x, y = cur_tracks[i]
x, y = int(x // w_down), int(y // h_down)
track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x
return track_pos # the position embeddings, [N, T', 2], 2 = height, width
def replace_feature(
vae_feature: torch.Tensor, # [B, C', T', H', W']
track_pos: torch.Tensor, # [B, N, T', 2]
strength: float = 1.0
) -> torch.Tensor:
b, _, t, h, w = vae_feature.shape
assert b == track_pos.shape[0], "Batch size mismatch."
n = track_pos.shape[1]
# Shuffle the trajectory order
track_pos = track_pos[:, torch.randperm(n), :, :]
# Extract coordinates at time steps ≥ 1 and generate a valid mask
current_pos = track_pos[:, :, 1:, :] # [B, N, T-1, 2]
mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) # [B, N, T-1]
# Get all valid indices
valid_indices = mask.nonzero(as_tuple=False) # [num_valid, 3]
num_valid = valid_indices.shape[0]
if num_valid == 0:
return vae_feature
# Decompose valid indices into each dimension
batch_idx = valid_indices[:, 0]
track_idx = valid_indices[:, 1]
t_rel = valid_indices[:, 2]
t_target = t_rel + 1 # Convert to original time step indices
# Extract target position coordinates
h_target = current_pos[batch_idx, track_idx, t_rel, 0].long() # Ensure integer indices
w_target = current_pos[batch_idx, track_idx, t_rel, 1].long()
# Extract source position coordinates (t=0)
h_source = track_pos[batch_idx, track_idx, 0, 0].long()
w_source = track_pos[batch_idx, track_idx, 0, 1].long()
# Get source features and assign to target positions
src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
return vae_feature
# Visualize functions
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
draw = ImageDraw.Draw(overlay, 'RGBA')
points = points[::-1]
# Compute total length
total_length = 0
segment_lengths = []
for i in range(len(points) - 1):
dx = points[i + 1][0] - points[i][0]
dy = points[i + 1][1] - points[i][1]
length = (dx * dx + dy * dy) ** 0.5
segment_lengths.append(length)
total_length += length
if total_length == 0:
return
accumulated_length = 0
# Draw the gradient polyline
for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])):
segment_length = segment_lengths[idx]
steps = max(int(segment_length), 1)
for i in range(steps):
current_length = accumulated_length + (i / steps) * segment_length
ratio = current_length / total_length
alpha = int(255 * (1 - ratio) * opacity)
color = (*start_color, alpha)
x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps)
y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps)
dynamic_line_width = max(int(line_width * (1 - ratio)), 1)
draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width)
accumulated_length += segment_length
def add_weighted(rgb, track):
rgb = np.array(rgb) # [H, W, C] "RGB"
track = np.array(track) # [H, W, C] "RGBA"
alpha = track[:, :, 3] / 255.0
alpha = np.stack([alpha] * 3, axis=-1)
blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha)
return Image.fromarray(blend_img.astype(np.uint8))
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
video = video.byte().cpu().numpy() # (81, 480, 832, 3)
tracks = tracks[0].long().detach().cpu().numpy()
if visibility is not None:
visibility = visibility[0].detach().cpu().numpy()
num_frames, height, width = video.shape[:3]
num_tracks = tracks.shape[1]
alpha_opacity = int(255 * opacity)
output_frames = []
for t in range(num_frames):
frame_rgb = video[t].astype(np.float32)
# Create a single RGBA overlay for all tracks in this frame
overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
draw_overlay = ImageDraw.Draw(overlay)
polyline_data = []
# Draw all circles on a single overlay
for n in range(num_tracks):
if visibility is not None and visibility[t, n] == 0:
continue
track_coord = tracks[t, n]
color = color_map[n % len(color_map)]
circle_color = color + (alpha_opacity,)
draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size),
fill=circle_color
)
# Store polyline data for batch processing
tracks_coord = tracks[max(t - track_frame, 0):t + 1, n]
if len(tracks_coord) > 1:
polyline_data.append((tracks_coord, color))
# Blend circles overlay once
overlay_np = np.array(overlay)
alpha = overlay_np[:, :, 3:4] / 255.0
frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
# Draw all polylines on a single overlay
if polyline_data:
polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
for tracks_coord, color in polyline_data:
_draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity)
# Blend polylines overlay once
polyline_np = np.array(polyline_overlay)
alpha = polyline_np[:, :, 3:4] / 255.0
frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8)))
return output_frames
class WanMoveVisualizeTracks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveVisualizeTracks",
category="conditioning/video_models",
inputs=[
io.Image.Input("images"),
io.Tracks.Input("tracks", optional=True),
io.Int.Input("line_resolution", default=24, min=1, max=1024),
io.Int.Input("circle_size", default=12, min=1, max=128),
io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
io.Int.Input("line_width", default=16, min=1, max=128),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput:
if tracks is None:
return io.NodeOutput(images)
track_path = tracks["track_path"].unsqueeze(0)
track_visibility = tracks["track_visibility"].unsqueeze(0)
images_in = images * 255.0
if images_in.shape[0] != track_path.shape[1]:
repeat_count = track_path.shape[1] // images.shape[0]
images_in = images_in.repeat(repeat_count, 1, 1, 1)
track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float()
return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device()))
class WanMoveTracksFromCoords(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveTracksFromCoords",
category="conditioning/video_models",
inputs=[
io.String.Input("track_coords", force_input=True, default="[]", optional=True),
io.Mask.Input("track_mask", optional=True),
],
outputs=[
io.Tracks.Output(),
io.Int.Output(display_name="track_length"),
],
)
@classmethod
def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
device=comfy.model_management.intermediate_device()
tracks_data = parse_json_tracks(track_coords)
track_length = len(tracks_data[0])
track_list = [
[[track[frame]['x'], track[frame]['y']] for track in tracks_data]
for frame in range(len(tracks_data[0]))
]
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
num_tracks = tracks.shape[-2]
if track_mask is None:
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
else:
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
out_track_info = {}
out_track_info["track_path"] = tracks
out_track_info["track_visibility"] = track_visibility
return io.NodeOutput(out_track_info, track_length)
class GenerateTracks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GenerateTracks",
category="conditioning/video_models",
inputs=[
io.Int.Input("width", default=832, min=16, max=4096, step=16),
io.Int.Input("height", default=480, min=16, max=4096, step=16),
io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
io.Int.Input("num_frames", default=81, min=1, max=1024),
io.Int.Input("num_tracks", default=5, min=1, max=100),
io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
io.Combo.Input(
"interpolation",
options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
tooltip="Controls the timing/speed of movement along the path.",
),
io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
],
outputs=[
io.Tracks.Output(),
io.Int.Output(display_name="track_length"),
],
)
@classmethod
def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
device = comfy.model_management.intermediate_device()
track_length = num_frames
# normalized coordinates to pixel coordinates
start_x_px = start_x * width
start_y_px = start_y * height
mid_x_px = mid_x * width
mid_y_px = mid_y * height
end_x_px = end_x * width
end_y_px = end_y * height
track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional
t = torch.linspace(0, 1, num_frames, device=device)
if interpolation == "constant": # All points stay at start position
interp_values = torch.zeros_like(t)
elif interpolation == "linear":
interp_values = t
elif interpolation == "ease_in":
interp_values = t ** 2
elif interpolation == "ease_out":
interp_values = 1 - (1 - t) ** 2
elif interpolation == "ease_in_out":
interp_values = t * t * (3 - 2 * t)
if bezier: # apply interpolation to t for timing control along the bezier path
t_interp = interp_values
one_minus_t = 1 - t_interp
x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
else: # calculate base x and y positions for each frame (center track)
x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
# For non-bezier, tangent is constant (direction from start to end)
tangent_x = torch.full_like(t, end_x_px - start_x_px)
tangent_y = torch.full_like(t, end_y_px - start_y_px)
track_list = []
for frame_idx in range(num_frames):
# Calculate perpendicular direction at this frame
tx = tangent_x[frame_idx].item()
ty = tangent_y[frame_idx].item()
length = (tx ** 2 + ty ** 2) ** 0.5
if length > 0: # Perpendicular unit vector (rotate 90 degrees)
perp_x = -ty / length
perp_y = tx / length
else: # If tangent is zero, spread horizontally
perp_x = 1.0
perp_y = 0.0
frame_tracks = []
for track_idx in range(num_tracks): # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
track_x = x_positions[frame_idx].item() + perp_x * offset
track_y = y_positions[frame_idx].item() + perp_y * offset
frame_tracks.append([track_x, track_y])
track_list.append(frame_tracks)
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
if track_mask is None:
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
else:
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
out_track_info = {}
out_track_info["track_path"] = tracks
out_track_info["track_visibility"] = track_visibility
return io.NodeOutput(out_track_info, track_length)
class WanMoveConcatTrack(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveConcatTrack",
category="conditioning/video_models",
inputs=[
io.Tracks.Input("tracks_1"),
io.Tracks.Input("tracks_2", optional=True),
],
outputs=[
io.Tracks.Output(),
],
)
@classmethod
def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput:
if tracks_2 is None:
return io.NodeOutput(tracks_1)
tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1) # Concatenate along the track dimension
mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1)
out_track_info = {}
out_track_info["track_path"] = tracks_out
out_track_info["track_visibility"] = mask_out
return io.NodeOutput(out_track_info)
class WanMoveTrackToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveTrackToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Tracks.Input("tracks", optional=True),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image"),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
device=comfy.model_management.intermediate_device()
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
if tracks is not None and strength > 0.0:
tracks_path = tracks["track_path"][:length] # [T, N, 2]
num_tracks = tracks_path.shape[-2]
track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device))
track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)
else:
concat_latent_image_pos = concat_latent_image
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanMoveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanMoveTrackToVideo,
WanMoveTracksFromCoords,
WanMoveConcatTrack,
WanMoveVisualizeTracks,
GenerateTracks,
]
async def comfy_entrypoint() -> WanMoveExtension:
return WanMoveExtension()

View File

@ -2358,6 +2358,7 @@ async def init_builtin_extra_nodes():
"nodes_logic.py",
"nodes_nop.py",
"nodes_kandinsky5.py",
"nodes_wanmove.py",
]
import_failed = []

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.33.13
comfyui-frontend-package==1.34.8
comfyui-workflow-templates==0.7.54
comfyui-embedded-docs==0.3.1
torch