mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Add strength parameter and a node to generate tracks
This commit is contained in:
parent
4c1f71ebf6
commit
cf3561fd28
@ -86,6 +86,7 @@ def create_pos_embeddings(
|
||||
def replace_feature(
|
||||
vae_feature: torch.Tensor, # [B, C', T', H', W']
|
||||
track_pos: torch.Tensor, # [B, N, T', 2]
|
||||
strength: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
b, _, t, h, w = vae_feature.shape
|
||||
assert b == track_pos.shape[0], "Batch size mismatch."
|
||||
@ -121,7 +122,10 @@ def replace_feature(
|
||||
|
||||
# Get source features and assign to target positions
|
||||
src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
|
||||
vae_feature[batch_idx, :, t_target, h_target, w_target] = src_features
|
||||
dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]
|
||||
|
||||
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
|
||||
|
||||
|
||||
return vae_feature
|
||||
|
||||
@ -315,6 +319,115 @@ class WanMoveTracksFromCoords(io.ComfyNode):
|
||||
return io.NodeOutput(out_track_info, track_length)
|
||||
|
||||
|
||||
class GenerateTracks(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GenerateTracks",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=832, min=16, max=4096, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=4096, step=16),
|
||||
io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
|
||||
io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
|
||||
io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
|
||||
io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
|
||||
io.Int.Input("num_frames", default=81, min=1, max=1024),
|
||||
io.Int.Input("num_tracks", default=5, min=1, max=100),
|
||||
io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
|
||||
io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
|
||||
io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
|
||||
io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
|
||||
io.Combo.Input(
|
||||
"interpolation",
|
||||
options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
|
||||
tooltip="Controls the timing/speed of movement along the path.",
|
||||
),
|
||||
io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
|
||||
],
|
||||
outputs=[
|
||||
io.Tracks.Output(),
|
||||
io.Int.Output(display_name="track_length"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
|
||||
track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
|
||||
device = comfy.model_management.intermediate_device()
|
||||
track_length = num_frames
|
||||
|
||||
# normalized coordinates to pixel coordinates
|
||||
start_x_px = start_x * width
|
||||
start_y_px = start_y * height
|
||||
mid_x_px = mid_x * width
|
||||
mid_y_px = mid_y * height
|
||||
end_x_px = end_x * width
|
||||
end_y_px = end_y * height
|
||||
|
||||
track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional
|
||||
|
||||
t = torch.linspace(0, 1, num_frames, device=device)
|
||||
if interpolation == "constant": # All points stay at start position
|
||||
interp_values = torch.zeros_like(t)
|
||||
elif interpolation == "linear":
|
||||
interp_values = t
|
||||
elif interpolation == "ease_in":
|
||||
interp_values = t ** 2
|
||||
elif interpolation == "ease_out":
|
||||
interp_values = 1 - (1 - t) ** 2
|
||||
elif interpolation == "ease_in_out":
|
||||
interp_values = t * t * (3 - 2 * t)
|
||||
|
||||
if bezier: # apply interpolation to t for timing control along the bezier path
|
||||
t_interp = interp_values
|
||||
one_minus_t = 1 - t_interp
|
||||
x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
|
||||
y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
|
||||
tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
|
||||
tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
|
||||
else: # calculate base x and y positions for each frame (center track)
|
||||
x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
|
||||
y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
|
||||
# For non-bezier, tangent is constant (direction from start to end)
|
||||
tangent_x = torch.full_like(t, end_x_px - start_x_px)
|
||||
tangent_y = torch.full_like(t, end_y_px - start_y_px)
|
||||
|
||||
track_list = []
|
||||
for frame_idx in range(num_frames):
|
||||
# Calculate perpendicular direction at this frame
|
||||
tx = tangent_x[frame_idx].item()
|
||||
ty = tangent_y[frame_idx].item()
|
||||
length = (tx ** 2 + ty ** 2) ** 0.5
|
||||
|
||||
if length > 0: # Perpendicular unit vector (rotate 90 degrees)
|
||||
perp_x = -ty / length
|
||||
perp_y = tx / length
|
||||
else: # If tangent is zero, spread horizontally
|
||||
perp_x = 1.0
|
||||
perp_y = 0.0
|
||||
|
||||
frame_tracks = []
|
||||
for track_idx in range(num_tracks): # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
|
||||
offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
|
||||
track_x = x_positions[frame_idx].item() + perp_x * offset
|
||||
track_y = y_positions[frame_idx].item() + perp_y * offset
|
||||
frame_tracks.append([track_x, track_y])
|
||||
track_list.append(frame_tracks)
|
||||
|
||||
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
|
||||
|
||||
if track_mask is None:
|
||||
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
|
||||
else:
|
||||
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
|
||||
|
||||
out_track_info = {}
|
||||
out_track_info["track_path"] = tracks
|
||||
out_track_info["track_visibility"] = track_visibility
|
||||
return io.NodeOutput(out_track_info, track_length)
|
||||
|
||||
|
||||
class WanMoveConcatTrack(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -355,6 +468,7 @@ class WanMoveTrackToVideo(io.ComfyNode):
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Tracks.Input("tracks", optional=True),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
@ -370,7 +484,7 @@ class WanMoveTrackToVideo(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||
device=comfy.model_management.intermediate_device()
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
|
||||
if start_image is not None:
|
||||
@ -382,7 +496,7 @@ class WanMoveTrackToVideo(io.ComfyNode):
|
||||
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
if tracks is not None:
|
||||
if tracks is not None and strength > 0.0:
|
||||
tracks_path = tracks["track_path"][:length] # [T, N, 2]
|
||||
num_tracks = tracks_path.shape[-2]
|
||||
|
||||
@ -390,7 +504,7 @@ class WanMoveTrackToVideo(io.ComfyNode):
|
||||
|
||||
track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
|
||||
track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
|
||||
concat_latent_image_pos = replace_feature(concat_latent_image, track_pos)
|
||||
concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)
|
||||
else:
|
||||
concat_latent_image_pos = concat_latent_image
|
||||
|
||||
@ -414,6 +528,7 @@ class WanMoveExtension(ComfyExtension):
|
||||
WanMoveTracksFromCoords,
|
||||
WanMoveConcatTrack,
|
||||
WanMoveVisualizeTracks,
|
||||
GenerateTracks,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanMoveExtension:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user