Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-16 01:37:04 +08:00)

Compare commits: 19 commits, 133184e182...90f17a27e1

Commits included:
90f17a27e1, 6592bffc60, 971cefe7d4, da2bfb5b0a, c5a47a1692, 908fd7d749, 5495589db3,
982876d59a, 338d9ae3bb, eeb020b9b7, ae65433a60, fdebe18296, f8321eb57b, 93948e3fc5,
e711aaf1a7, 57ddb7fd13, 17c92a9f28, 36357bbcc3, f668c2e3c9
@@ -53,6 +53,16 @@ try:
     repo.stash(ident)
 except KeyError:
     print("nothing to stash") # noqa: T201
+except:
+    print("Could not stash, cleaning index and trying again.") # noqa: T201
+    repo.state_cleanup()
+    repo.index.read_tree(repo.head.peel().tree)
+    repo.index.write()
+    try:
+        repo.stash(ident)
+    except KeyError:
+        print("nothing to stash.") # noqa: T201
+
 backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
 print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
 try:
@@ -58,8 +58,13 @@ class InternalRoutes:
             return web.json_response({"error": "Invalid directory type"}, status=400)

         directory = get_directory_by_type(directory_type)

+        def is_visible_file(entry: os.DirEntry) -> bool:
+            """Filter out hidden files (e.g., .DS_Store on macOS)."""
+            return entry.is_file() and not entry.name.startswith('.')
+
         sorted_files = sorted(
-            (entry for entry in os.scandir(directory) if entry.is_file()),
+            (entry for entry in os.scandir(directory) if is_visible_file(entry)),
            key=lambda entry: -entry.stat().st_mtime
         )
         return web.json_response([entry.name for entry in sorted_files], status=200)
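For reference, a minimal standalone sketch of the filtering behavior introduced above; the directory path is a placeholder, only the predicate and sort key mirror the diff:

import os

def is_visible_file(entry: os.DirEntry) -> bool:
    # Hidden files such as .DS_Store start with a dot and are skipped.
    return entry.is_file() and not entry.name.startswith('.')

files = sorted(
    (e for e in os.scandir(".") if is_visible_file(e)),  # "." is a placeholder directory
    key=lambda e: -e.stat().st_mtime,                     # newest first, as in the route
)
print([e.name for e in files])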
@@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None


 @torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
     """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
     arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
     """
+    if solver_type not in {"phi_1", "phi_2"}:
+        raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
+
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler

@@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
         denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

         # Step 2
-        denoised_d = torch.lerp(denoised, denoised_2, fac)
-        x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        if solver_type == "phi_1":
+            denoised_d = torch.lerp(denoised, denoised_2, fac)
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        elif solver_type == "phi_2":
+            b2 = ei_h_phi_2(-h_eta) / r
+            b1 = ei_h_phi_1(-h_eta) - b2
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)

         if inject_noise:
             segment_factor = (r - 1) * h * eta
             sde_noise = sde_noise * segment_factor.exp()
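A rough illustration of the phi_2 branch added above. It assumes the usual exponential-integrator functions, phi_1(x) = (e^x - 1)/x and phi_2(x) = (phi_1(x) - 1)/x; how ei_h_phi_1/ei_h_phi_2 are actually defined is not shown in this diff, so treat these definitions and the numeric values as illustrative only:

import torch

def phi_1(x):
    return torch.expm1(x) / x        # (e^x - 1) / x

def phi_2(x):
    return (phi_1(x) - 1.0) / x      # (phi_1(x) - 1) / x

x = torch.tensor(-0.25)              # stands in for -h_eta at one step (arbitrary value)
r = 0.5
b2 = phi_2(x) / r
b1 = phi_1(x) - b2
# The update then combines the two model evaluations as b1 * denoised + b2 * denoised_2.
print(b1.item(), b2.item())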
@@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
         ffn_dim_multiplier: float = (8.0 / 3.0),
         norm_eps: float = 1e-5,
         qk_norm: bool = True,
+        n_control_layers=6,
+        control_in_dim=16,
+        additional_in_dim=0,
+        broken=False,
+        refiner_control=False,
         dtype=None,
         device=None,
         operations=None,

@@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
         super().__init__()
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}

-        self.additional_in_dim = 0
-        self.control_in_dim = 16
+        self.broken = broken
+        self.additional_in_dim = additional_in_dim
+        self.control_in_dim = control_in_dim
         n_refiner_layers = 2
-        self.n_control_layers = 6
+        self.n_control_layers = n_control_layers
         self.control_layers = nn.ModuleList(
             [
                 ZImageControlTransformerBlock(

@@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module):
         all_x_embedder = {}
         patch_size = 2
         f_patch_size = 1
-        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
+        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
         all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

+        self.refiner_control = refiner_control
+
         self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
-        self.control_noise_refiner = nn.ModuleList(
-            [
-                JointTransformerBlock(
-                    layer_id,
-                    dim,
-                    n_heads,
-                    n_kv_heads,
-                    multiple_of,
-                    ffn_dim_multiplier,
-                    norm_eps,
-                    qk_norm,
-                    modulation=True,
-                    z_image_modulation=True,
-                    operation_settings=operation_settings,
-                )
-                for layer_id in range(n_refiner_layers)
-            ]
-        )
+        if self.refiner_control:
+            self.control_noise_refiner = nn.ModuleList(
+                [
+                    ZImageControlTransformerBlock(
+                        layer_id,
+                        dim,
+                        n_heads,
+                        n_kv_heads,
+                        multiple_of,
+                        ffn_dim_multiplier,
+                        norm_eps,
+                        qk_norm,
+                        block_id=layer_id,
+                        operation_settings=operation_settings,
+                    )
+                    for layer_id in range(n_refiner_layers)
+                ]
+            )
+        else:
+            self.control_noise_refiner = nn.ModuleList(
+                [
+                    JointTransformerBlock(
+                        layer_id,
+                        dim,
+                        n_heads,
+                        n_kv_heads,
+                        multiple_of,
+                        ffn_dim_multiplier,
+                        norm_eps,
+                        qk_norm,
+                        modulation=True,
+                        z_image_modulation=True,
+                        operation_settings=operation_settings,
+                    )
+                    for layer_id in range(n_refiner_layers)
+                ]
+            )

     def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
         patch_size = 2

@@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
         control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))

         x_attn_mask = None
-        for layer in self.control_noise_refiner:
-            control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+        if not self.refiner_control:
+            for layer in self.control_noise_refiner:
+                control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)

         return control_context

+    def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+        if self.refiner_control:
+            if self.broken:
+                if layer_id == 0:
+                    return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+                if layer_id > 0:
+                    out = None
+                    for i in range(1, len(self.control_layers)):
+                        o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+                        if out is None:
+                            out = o
+
+                    return (out, control_context)
+            else:
+                return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+        else:
+            return (None, control_context)
+
     def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
         return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
@@ -536,6 +536,7 @@ class NextDiT(nn.Module):
         bsz = len(x)
         pH = pW = self.patch_size
         device = x[0].device
+        orig_x = x

         if self.pad_tokens_multiple is not None:
             pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple

@@ -572,13 +573,21 @@ class NextDiT(nn.Module):

         freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)

+        patches = transformer_options.get("patches", {})
+
         # refine context
         for layer in self.context_refiner:
             cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)

         padded_img_mask = None
-        for layer in self.noise_refiner:
+        x_input = x
+        for i, layer in enumerate(self.noise_refiner):
             x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
+            if "noise_refiner" in patches:
+                for p in patches["noise_refiner"]:
+                    out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
+                    if "img" in out:
+                        x = out["img"]

         padded_full_embed = torch.cat((cap_feats, x), dim=1)
         mask = None

@@ -622,14 +631,15 @@ class NextDiT(nn.Module):

         patches = transformer_options.get("patches", {})
         x_is_tensor = isinstance(x, torch.Tensor)
-        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
         freqs_cis = freqs_cis.to(img.device)

+        img_input = img
         for i, layer in enumerate(self.layers):
             img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
             if "double_block" in patches:
                 for p in patches["double_block"]:
-                    out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
                     if "img" in out:
                         img[:, cap_size[0]:] = out["img"]
                     if "txt" in out:
@@ -259,8 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["nerf_tile_size"] = 512
         dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
         dit_config["nerf_embedder_dtype"] = torch.float32
         if "__x0__" in state_dict_keys: # x0 pred
             dit_config["use_x0"] = True
+        else:
+            dit_config["use_x0"] = False
     else:
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
@@ -454,6 +454,9 @@ class ModelPatcher:
     def set_model_post_input_patch(self, patch):
         self.set_model_patch(patch, "post_input")

+    def set_model_noise_refiner_patch(self, patch):
+        self.set_model_patch(patch, "noise_refiner")
+
     def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
         rope_options = self.model_options["transformer_options"].get("rope_options", {})
         rope_options["scale_x"] = scale_x
comfy/ops.py

@@ -497,15 +497,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
         ) -> None:
             super().__init__()

-            self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
-            # self.factory_kwargs = {"device": device, "dtype": dtype}
+            if dtype is None:
+                dtype = MixedPrecisionOps._compute_dtype
+
+            self.factory_kwargs = {"device": device, "dtype": dtype}

             self.in_features = in_features
             self.out_features = out_features
-            if bias:
-                self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
-            else:
-                self.register_parameter("bias", None)
+            self._has_bias = bias

             self.tensor_class = None
             self._full_precision_mm = MixedPrecisionOps._full_precision_mm

@@ -530,7 +529,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 layer_conf = json.loads(layer_conf.numpy().tobytes())

             if layer_conf is None:
-                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+                dtype = self.factory_kwargs["dtype"]
+                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
+                if dtype != MixedPrecisionOps._compute_dtype:
+                    self.comfy_cast_weights = True
+                if self._has_bias:
+                    self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
+                else:
+                    self.register_parameter("bias", None)
             else:
                 self.quant_format = layer_conf.get("format", None)
                 if not self._full_precision_mm:

@@ -560,6 +566,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     requires_grad=False
                 )

+                if self._has_bias:
+                    self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
+                else:
+                    self.register_parameter("bias", None)
+
                 for param_name in qconfig["parameters"]:
                     param_key = f"{prefix}{param_name}"
                     _v = state_dict.pop(param_key, None)

@@ -581,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             quant_conf = {"format": self.quant_format}
             if self._full_precision_mm:
                 quant_conf["full_precision_matrix_mult"] = True
-            sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+            sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
             return sd

         def _forward(self, input, weight, bias):
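A small sketch of what the serialization change above amounts to: both forms encode the JSON config as a uint8 tensor, but torch.tensor(list(...)) copies the bytes into a fresh, writable tensor instead of sharing memory with the read-only bytes object the way torch.frombuffer does. The config value here is a hypothetical example, shown only for the round trip:

import json
import torch

quant_conf = {"format": "float8_e4m3fn"}           # hypothetical example config
raw = json.dumps(quant_conf).encode("utf-8")

t = torch.tensor(list(raw), dtype=torch.uint8)     # copies; no shared read-only buffer
decoded = json.loads(t.numpy().tobytes())          # same decode path used when loading
assert decoded == quant_conf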
@@ -549,8 +549,10 @@ class VAE:
                 ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
                 self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
                 self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
-                self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
-                self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+                self.memory_used_encode = lambda shape, dtype: (1500 if shape[2] <= 4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (2200 if shape[2] <= 4 else 7000) * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)


             # Hunyuan 3d v2 2.0 & 2.1
             elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
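To put the new encode estimate in numbers, here is a minimal sketch of the updated lambda. It assumes shape is indexed as [batch, channels, frames, height, width] (the shape[2] <= 4 frame check suggests this), and the example shapes and dtype size are arbitrary:

def memory_used_encode(shape, dtype_size):
    # mirrors the updated lambda: cheaper constant for short clips (<= 4 frames)
    return (1500 if shape[2] <= 4 else 6000) * shape[3] * shape[4] * dtype_size

# hypothetical 1-frame vs 81-frame latents at fp16 (2 bytes per element)
print(memory_used_encode((1, 16, 1, 480, 832), 2) / 1e9)   # ~1.2 GB reserved
print(memory_used_encode((1, 16, 81, 480, 832), 2) / 1e9)  # ~4.8 GB reserved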
@@ -541,7 +541,7 @@ class SD3(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.SD3

-    memory_usage_factor = 1.2
+    memory_usage_factor = 1.6

     text_encoder_key_prefix = ["text_encoders."]

@@ -965,7 +965,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):

    def __init__(self, unet_config):
        super().__init__(unet_config)
-        self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
+        self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.CosmosPredict2(self, device=device)

@@ -1026,7 +1026,7 @@ class ZImage(Lumina2):
        "shift": 3.0,
    }

-    memory_usage_factor = 1.7
+    memory_usage_factor = 2.0

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

@@ -1289,7 +1289,7 @@ class ChromaRadiance(Chroma):
    latent_format = comfy.latent_formats.ChromaRadiance

    # Pixel-space model, no spatial compression for model input.
-    memory_usage_factor = 0.038
+    memory_usage_factor = 0.044

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.ChromaRadiance(self, device=device)

@@ -1332,7 +1332,7 @@ class Omnigen2(supported_models_base.BASE):
        "shift": 2.6,
    }

-    memory_usage_factor = 1.65 #TODO
+    memory_usage_factor = 1.95 #TODO

    unet_extra_config = {}
    latent_format = latent_formats.Flux

@@ -1397,7 +1397,7 @@ class HunyuanImage21(HunyuanVideo):

    latent_format = latent_formats.HunyuanImage21

-    memory_usage_factor = 7.7
+    memory_usage_factor = 8.7

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

@@ -1488,7 +1488,7 @@ class Kandinsky5(supported_models_base.BASE):
    unet_extra_config = {}
    latent_format = latent_formats.HunyuanVideo

-    memory_usage_factor = 1.1 #TODO
+    memory_usage_factor = 1.25 #TODO

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

@@ -1517,7 +1517,7 @@ class Kandinsky5Image(Kandinsky5):
    }

    latent_format = latent_formats.Flux
-    memory_usage_factor = 1.1 #TODO
+    memory_usage_factor = 1.25 #TODO

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.Kandinsky5Image(self, device=device)
@@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
         if quant_metadata is not None:
             layers = quant_metadata["layers"]
             for k, v in layers.items():
-                state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
+                state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)

     return state_dict, metadata
@@ -774,6 +774,13 @@ class AudioEncoder(ComfyTypeIO):
 class AudioEncoderOutput(ComfyTypeIO):
     Type = Any

+
+@comfytype(io_type="TRACKS")
+class Tracks(ComfyTypeIO):
+    class TrackDict(TypedDict):
+        track_path: torch.Tensor
+        track_visibility: torch.Tensor
+    Type = TrackDict
+
 @comfytype(io_type="COMFY_MULTITYPED_V3")
 class MultiType:
     Type = Any

@@ -1891,7 +1898,7 @@ class NodeOutput(_NodeOutputInternal):
             ui = data["ui"]
         if "expand" in data:
             expand = data["expand"]
-        return cls(args=args, ui=ui, expand=expand)
+        return cls(*args, ui=ui, expand=expand)

     def __getitem__(self, index) -> Any:
         return self.args[index]

@@ -1970,6 +1977,7 @@ __all__ = [
     "SEGS",
     "AnyType",
     "MultiType",
+    "Tracks",
     # Dynamic Types
     "MatchType",
     # "DynamicCombo",
@@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel):
     url: str = Field(..., description="URL for generated image")


-class OmniTaskStatusResults(BaseModel):
+class TaskStatusResults(BaseModel):
     videos: list[TaskStatusVideoResult] | None = Field(None)
     images: list[TaskStatusImageResult] | None = Field(None)


-class OmniTaskStatusResponseData(BaseModel):
+class TaskStatusResponseData(BaseModel):
     created_at: int | None = Field(None, description="Task creation time")
     updated_at: int | None = Field(None, description="Task update time")
     task_status: str | None = None
     task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
     task_id: str | None = Field(None, description="Task ID")
-    task_result: OmniTaskStatusResults | None = Field(None)
+    task_result: TaskStatusResults | None = Field(None)


-class OmniTaskStatusResponse(BaseModel):
+class TaskStatusResponse(BaseModel):
     code: int | None = Field(None, description="Error code")
     message: str | None = Field(None, description="Error message")
     request_id: str | None = Field(None, description="Request ID")
-    data: OmniTaskStatusResponseData | None = Field(None)
+    data: TaskStatusResponseData | None = Field(None)


 class OmniImageParamImage(BaseModel):

@@ -84,3 +84,21 @@ class OmniProImageRequest(BaseModel):
     mode: str = Field("pro")
     n: int | None = Field(1, le=9)
     image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
+
+
+class TextToVideoWithAudioRequest(BaseModel):
+    model_name: str = Field(..., description="kling-v2-6")
+    aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
+    duration: str = Field(..., description="'5' or '10'")
+    prompt: str = Field(...)
+    mode: str = Field("pro")
+    sound: str = Field(..., description="'on' or 'off'")
+
+
+class ImageToVideoWithAudioRequest(BaseModel):
+    model_name: str = Field(..., description="kling-v2-6")
+    image: str = Field(...)
+    duration: str = Field(..., description="'5' or '10'")
+    prompt: str = Field(...)
+    mode: str = Field("pro")
+    sound: str = Field(..., description="'on' or 'off'")
@@ -50,6 +50,7 @@ from comfy_api_nodes.apis import (
     KlingSingleImageEffectModelName,
 )
 from comfy_api_nodes.apis.kling_api import (
+    ImageToVideoWithAudioRequest,
     OmniImageParamImage,
     OmniParamImage,
     OmniParamVideo,

@@ -57,7 +58,8 @@ from comfy_api_nodes.apis.kling_api import (
     OmniProImageRequest,
     OmniProReferences2VideoRequest,
     OmniProText2VideoRequest,
-    OmniTaskStatusResponse,
+    TaskStatusResponse,
+    TextToVideoWithAudioRequest,
 )
 from comfy_api_nodes.util import (
     ApiEndpoint,

@@ -242,7 +244,7 @@ def normalize_omni_prompt_references(prompt: str) -> str:
     return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)


-async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
     if response.code:
         raise RuntimeError(
             f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"

@@ -250,7 +252,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStat
     final_response = await poll_op(
         cls,
         ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
-        response_model=OmniTaskStatusResponse,
+        response_model=TaskStatusResponse,
         status_extractor=lambda r: (r.data.task_status if r.data else None),
         max_poll_attempts=160,
     )

@@ -483,12 +485,12 @@ async def execute_image2video(
     task_id = task_creation_response.data.task_id

     final_response = await poll_op(
         cls,
         ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
         response_model=KlingImage2VideoResponse,
         estimated_duration=AVERAGE_DURATION_I2V,
         status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
     )
     validate_video_result_response(final_response)

     video = get_video_from_response(final_response)

@@ -834,7 +836,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=OmniTaskStatusResponse,
+            response_model=TaskStatusResponse,
             data=OmniProText2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,

@@ -929,7 +931,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=OmniTaskStatusResponse,
+            response_model=TaskStatusResponse,
             data=OmniProFirstLastFrameRequest(
                 model_name=model_name,
                 prompt=prompt,

@@ -997,7 +999,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=OmniTaskStatusResponse,
+            response_model=TaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,

@@ -1081,7 +1083,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=OmniTaskStatusResponse,
+            response_model=TaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,

@@ -1162,7 +1164,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
-            response_model=OmniTaskStatusResponse,
+            response_model=TaskStatusResponse,
             data=OmniProReferences2VideoRequest(
                 model_name=model_name,
                 prompt=prompt,

@@ -1237,7 +1239,7 @@ class OmniProImageNode(IO.ComfyNode):
         response = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
-            response_model=OmniTaskStatusResponse,
+            response_model=TaskStatusResponse,
             data=OmniProImageRequest(
                 model_name=model_name,
                 prompt=prompt,

@@ -1253,7 +1255,7 @@ class OmniProImageNode(IO.ComfyNode):
         final_response = await poll_op(
             cls,
             ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
-            response_model=OmniTaskStatusResponse,
+            response_model=TaskStatusResponse,
             status_extractor=lambda r: (r.data.task_status if r.data else None),
         )
         return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))

@@ -1328,9 +1330,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="KlingImage2VideoNode",
-            display_name="Kling Image to Video",
+            display_name="Kling Image(First Frame) to Video",
             category="api node/video/Kling",
-            description="Kling Image to Video Node",
             inputs=[
                 IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
                 IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),

@@ -2034,6 +2035,136 @@ class KlingImageGenerationNode(IO.ComfyNode):
         return IO.NodeOutput(await image_result_to_node_output(images))


+class TextToVideoWithAudio(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingTextToVideoWithAudio",
+            display_name="Kling Text to Video with Audio",
+            category="api node/video/Kling",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-v2-6"]),
+                IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
+                IO.Combo.Input("mode", options=["pro"]),
+                IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+                IO.Combo.Input("duration", options=[5, 10]),
+                IO.Boolean.Input("generate_audio", default=True),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        prompt: str,
+        mode: str,
+        aspect_ratio: str,
+        duration: int,
+        generate_audio: bool,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, min_length=1, max_length=2500)
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
+            response_model=TaskStatusResponse,
+            data=TextToVideoWithAudioRequest(
+                model_name=model_name,
+                prompt=prompt,
+                mode=mode,
+                aspect_ratio=aspect_ratio,
+                duration=str(duration),
+                sound="on" if generate_audio else "off",
+            ),
+        )
+        if response.code:
+            raise RuntimeError(
+                f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+            )
+        final_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
+            response_model=TaskStatusResponse,
+            status_extractor=lambda r: (r.data.task_status if r.data else None),
+        )
+        return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
+class ImageToVideoWithAudio(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls) -> IO.Schema:
+        return IO.Schema(
+            node_id="KlingImageToVideoWithAudio",
+            display_name="Kling Image(First Frame) to Video with Audio",
+            category="api node/video/Kling",
+            inputs=[
+                IO.Combo.Input("model_name", options=["kling-v2-6"]),
+                IO.Image.Input("start_frame"),
+                IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
+                IO.Combo.Input("mode", options=["pro"]),
+                IO.Combo.Input("duration", options=[5, 10]),
+                IO.Boolean.Input("generate_audio", default=True),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_name: str,
+        start_frame: Input.Image,
+        prompt: str,
+        mode: str,
+        duration: int,
+        generate_audio: bool,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, min_length=1, max_length=2500)
+        validate_image_dimensions(start_frame, min_width=300, min_height=300)
+        validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
+            response_model=TaskStatusResponse,
+            data=ImageToVideoWithAudioRequest(
+                model_name=model_name,
+                image=(await upload_images_to_comfyapi(cls, start_frame))[0],
+                prompt=prompt,
+                mode=mode,
+                duration=str(duration),
+                sound="on" if generate_audio else "off",
+            ),
+        )
+        if response.code:
+            raise RuntimeError(
+                f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+            )
+        final_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
+            response_model=TaskStatusResponse,
+            status_extractor=lambda r: (r.data.task_status if r.data else None),
+        )
+        return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
 class KlingExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[IO.ComfyNode]]:

@@ -2056,7 +2187,9 @@ class KlingExtension(ComfyExtension):
             OmniProImageToVideoNode,
             OmniProVideoToVideoNode,
             OmniProEditVideoNode,
-            # OmniProImageNode, # need support from backend
+            OmniProImageNode,
+            TextToVideoWithAudio,
+            ImageToVideoWithAudio,
         ]
@@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
     get_sampler = execute


+class SamplerSEEDS2(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SamplerSEEDS2",
+            category="sampling/custom_sampling/samplers",
+            inputs=[
+                io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
+                io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
+                io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
+                io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
+            ],
+            outputs=[io.Sampler.Output()]
+        )
+
+    @classmethod
+    def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
+        sampler_name = "seeds_2"
+        sampler = comfy.samplers.ksampler(
+            sampler_name,
+            {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
+        )
+        return io.NodeOutput(sampler)
+
+
 class Noise_EmptyNoise:
     def __init__(self):
         self.seed = 0

@@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
             SamplerDPMAdaptative,
             SamplerER_SDE,
             SamplerSASolver,
+            SamplerSEEDS2,
             SplitSigmas,
             SplitSigmasDenoise,
             FlipSigmas,
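A short sketch of how the new node builds its sampler object, following the execute method above. It needs a ComfyUI environment to run, and the option values are arbitrary examples:

import comfy.samplers

# Same call the SamplerSEEDS2 node makes; "seeds_2" routes to sample_seeds_2 with these extra options.
sampler = comfy.samplers.ksampler(
    "seeds_2",
    {"eta": 1.0, "s_noise": 1.0, "r": 0.5, "solver_type": "phi_2"},
)
# The resulting SAMPLER output plugs into custom-sampling workflows like any other sampler object.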
@@ -243,7 +243,13 @@ class ModelPatchLoader:
             model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
         elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
             sd = z_image_convert(sd)
-            model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+            config = {}
+            if 'control_layers.14.adaLN_modulation.0.weight' in sd:
+                config['n_control_layers'] = 15
+                config['additional_in_dim'] = 17
+                config['refiner_control'] = True
+                config['broken'] = True
+            model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)

         model.load_state_dict(sd)
         model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())

@@ -297,56 +303,86 @@ class DiffSynthCnetPatch:
         return [self.model_patch]

 class ZImageControlPatch:
-    def __init__(self, model_patch, vae, image, strength):
+    def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
         self.model_patch = model_patch
         self.vae = vae
         self.image = image
+        self.inpaint_image = inpaint_image
+        self.mask = mask
         self.strength = strength
         self.encoded_image = self.encode_latent_cond(image)
         self.encoded_image_size = (image.shape[1], image.shape[2])
         self.temp_data = None

-    def encode_latent_cond(self, image):
-        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
-        return latent_image
+    def encode_latent_cond(self, control_image, inpaint_image=None):
+        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
+        if self.model_patch.model.additional_in_dim > 0:
+            if self.mask is None:
+                mask_ = torch.zeros_like(latent_image)[:, :1]
+            else:
+                mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
+            if inpaint_image is None:
+                inpaint_image = torch.ones_like(control_image) * 0.5
+
+            inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
+
+            return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
+        else:
+            return latent_image

     def __call__(self, kwargs):
         x = kwargs.get("x")
         img = kwargs.get("img")
+        img_input = kwargs.get("img_input")
         txt = kwargs.get("txt")
         pe = kwargs.get("pe")
         vec = kwargs.get("vec")
         block_index = kwargs.get("block_index")
+        block_type = kwargs.get("block_type", "")
         spacial_compression = self.vae.spacial_compression_encode()
         if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
             image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+            inpaint_scaled = None
+            if self.inpaint_image is not None:
+                inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
             loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
-            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
+            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
             self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
             comfy.model_management.load_models_gpu(loaded_models)

-        cnet_index = (block_index // 5)
-        cnet_index_float = (block_index / 5)
+        cnet_blocks = self.model_patch.model.n_control_layers
+        div = round(30 / cnet_blocks)
+
+        cnet_index = (block_index // div)
+        cnet_index_float = (block_index / div)
+
         kwargs.pop("img") # we do ops in place
         kwargs.pop("txt")

-        cnet_blocks = self.model_patch.model.n_control_layers
         if cnet_index_float > (cnet_blocks - 1):
             self.temp_data = None
             return kwargs

         if self.temp_data is None or self.temp_data[0] > cnet_index:
-            self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+            if block_type == "noise_refiner":
+                self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+            else:
+                self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))

-        while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
-            next_layer = self.temp_data[0] + 1
-            self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+        if block_type == "noise_refiner":
+            next_layer = self.temp_data[0] + 1
+            self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+            if self.temp_data[1][0] is not None:
+                img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+        else:
+            while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+                next_layer = self.temp_data[0] + 1
+                self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))

         if cnet_index_float == self.temp_data[0]:
             img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
         if cnet_blocks == self.temp_data[0] + 1:
             self.temp_data = None

         return kwargs

@@ -386,7 +422,9 @@ class QwenImageDiffsynthControlnet:
             mask = 1.0 - mask

         if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
-            model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+            patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask)
+            model_patched.set_model_noise_refiner_patch(patch)
+            model_patched.set_model_double_block_patch(patch)
         else:
             model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
         return (model_patched,)
comfy_extras/nodes_wanmove.py (new file, 535 lines)
@ -0,0 +1,535 @@
import nodes
import node_helpers
import torch
import torchvision.transforms.functional as TF
import comfy.model_management
import comfy.utils
import numpy as np
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_wan import parse_json_tracks

# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
from PIL import Image, ImageDraw

SKIP_ZERO = False

def get_pos_emb(
    pos_k: torch.Tensor,  # A 1D tensor containing positions for which to generate embeddings.
    pos_emb_dim: int,
    theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))),  # Function to compute thetas based on position and embedding dimensions.
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:  # The position embeddings (batch_size, pos_emb_dim)

    assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."
    pos_k = pos_k.to(device, dtype)
    if SKIP_ZERO:
        pos_k = pos_k + 1
    batch_size = pos_k.size(0)

    denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype)
    # Expand denominator to match the shape needed for broadcasting
    denominator_expanded = denominator.view(1, -1).expand(batch_size, -1)

    thetas = theta_func(denominator_expanded, pos_emb_dim)

    # Ensure pos_k is in the correct shape for broadcasting
    pos_k_expanded = pos_k.view(-1, 1).to(dtype)
    sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas))
    cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas))

    # Concatenate sine and cosine embeddings along the last dimension
    pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1)

    return pos_emb

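# Editorial note, not part of the committed file: get_pos_emb builds standard
# sinusoidal embeddings of shape [len(pos_k), pos_emb_dim]; it mirrors the
# upstream trajectory module linked above and does not appear to be called
# elsewhere in this file. A minimal, purely illustrative shape check:
def _get_pos_emb_shape_sketch():
    emb = get_pos_emb(torch.arange(4), 8)
    assert emb.shape == (4, 8)
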
def create_pos_embeddings(
    pred_tracks: torch.Tensor,  # the predicted tracks, [T, N, 2]
    pred_visibility: torch.Tensor,  # the predicted visibility [T, N]
    downsample_ratios: list[int],  # the ratios for downsampling time, height, and width
    height: int,  # the height of the feature map
    width: int,  # the width of the feature map
    track_num: int = -1,  # the number of tracks to use
    t_down_strategy: str = "sample",  # the strategy for downsampling time dimension
):
    assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."

    t, n, _ = pred_tracks.shape
    t_down, h_down, w_down = downsample_ratios
    track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long)

    if track_num == -1:
        track_num = n

    tracks_idx = torch.randperm(n)[:track_num]
    tracks = pred_tracks[:, tracks_idx]
    visibility = pred_visibility[:, tracks_idx]

    for t_idx in range(0, t, t_down):
        if t_down_strategy == "sample" or t_idx == 0:
            cur_tracks = tracks[t_idx]  # [N, 2]
            cur_visibility = visibility[t_idx]  # [N]
        else:
            cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0)
            cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0)

        for i in range(track_num):
            if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height:
                continue
            x, y = cur_tracks[i]
            x, y = int(x // w_down), int(y // h_down)
            track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x

    return track_pos  # the position embeddings, [N, T', 2], 2 = height, width

def replace_feature(
    vae_feature: torch.Tensor,  # [B, C', T', H', W']
    track_pos: torch.Tensor,  # [B, N, T', 2]
    strength: float = 1.0
) -> torch.Tensor:
    b, _, t, h, w = vae_feature.shape
    assert b == track_pos.shape[0], "Batch size mismatch."
    n = track_pos.shape[1]

    # Shuffle the trajectory order
    track_pos = track_pos[:, torch.randperm(n), :, :]

    # Extract coordinates at time steps ≥ 1 and generate a valid mask
    current_pos = track_pos[:, :, 1:, :]  # [B, N, T-1, 2]
    mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0)  # [B, N, T-1]

    # Get all valid indices
    valid_indices = mask.nonzero(as_tuple=False)  # [num_valid, 3]
    num_valid = valid_indices.shape[0]

    if num_valid == 0:
        return vae_feature

    # Decompose valid indices into each dimension
    batch_idx = valid_indices[:, 0]
    track_idx = valid_indices[:, 1]
    t_rel = valid_indices[:, 2]
    t_target = t_rel + 1  # Convert to original time step indices

    # Extract target position coordinates
    h_target = current_pos[batch_idx, track_idx, t_rel, 0].long()  # Ensure integer indices
    w_target = current_pos[batch_idx, track_idx, t_rel, 1].long()

    # Extract source position coordinates (t=0)
    h_source = track_pos[batch_idx, track_idx, 0, 0].long()
    w_source = track_pos[batch_idx, track_idx, 0, 1].long()

    # Get source features and assign to target positions
    src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
    dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]

    vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength

    return vae_feature

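# Editorial note, not part of the committed file: a minimal sketch of what
# replace_feature does, using toy shapes chosen only for illustration. Each
# valid track entry blends the frame-0 latent feature at the track's starting
# cell into the cell the track occupies at a later latent frame, weighted by
# `strength`.
def _replace_feature_sketch():
    feat = torch.zeros(1, 2, 3, 4, 4)                 # [B, C', T', H', W'] toy latent
    feat[0, :, 0, 1, 1] = 1.0                         # distinctive feature at frame 0, cell (h=1, w=1)
    pos = -torch.ones(1, 1, 3, 2, dtype=torch.long)   # one track, marked invisible (-1) by default
    pos[0, 0, 0] = torch.tensor([1, 1])               # track starts at (h=1, w=1)
    pos[0, 0, 2] = torch.tensor([2, 3])               # and is visible again at frame 2, cell (h=2, w=3)
    out = replace_feature(feat.clone(), pos, strength=1.0)
    assert torch.allclose(out[0, :, 2, 2, 3], feat[0, :, 0, 1, 1])
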
# Visualize functions

def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
    draw = ImageDraw.Draw(overlay, 'RGBA')
    points = points[::-1]

    # Compute total length
    total_length = 0
    segment_lengths = []
    for i in range(len(points) - 1):
        dx = points[i + 1][0] - points[i][0]
        dy = points[i + 1][1] - points[i][1]
        length = (dx * dx + dy * dy) ** 0.5
        segment_lengths.append(length)
        total_length += length

    if total_length == 0:
        return

    accumulated_length = 0

    # Draw the gradient polyline
    for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])):
        segment_length = segment_lengths[idx]
        steps = max(int(segment_length), 1)

        for i in range(steps):
            current_length = accumulated_length + (i / steps) * segment_length
            ratio = current_length / total_length

            alpha = int(255 * (1 - ratio) * opacity)
            color = (*start_color, alpha)

            x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps)
            y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps)

            dynamic_line_width = max(int(line_width * (1 - ratio)), 1)
            draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width)

        accumulated_length += segment_length

def add_weighted(rgb, track):
    rgb = np.array(rgb)  # [H, W, C] "RGB"
    track = np.array(track)  # [H, W, C] "RGBA"

    alpha = track[:, :, 3] / 255.0
    alpha = np.stack([alpha] * 3, axis=-1)
    blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha)

    return Image.fromarray(blend_img.astype(np.uint8))

def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
    color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]

    video = video.byte().cpu().numpy()  # (81, 480, 832, 3)
    tracks = tracks[0].long().detach().cpu().numpy()
    if visibility is not None:
        visibility = visibility[0].detach().cpu().numpy()

    num_frames, height, width = video.shape[:3]
    num_tracks = tracks.shape[1]
    alpha_opacity = int(255 * opacity)

    output_frames = []
    for t in range(num_frames):
        frame_rgb = video[t].astype(np.float32)

        # Create a single RGBA overlay for all tracks in this frame
        overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        draw_overlay = ImageDraw.Draw(overlay)

        polyline_data = []

        # Draw all circles on a single overlay
        for n in range(num_tracks):
            if visibility is not None and visibility[t, n] == 0:
                continue

            track_coord = tracks[t, n]
            color = color_map[n % len(color_map)]
            circle_color = color + (alpha_opacity,)

            draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size),
                fill=circle_color
            )

            # Store polyline data for batch processing
            tracks_coord = tracks[max(t - track_frame, 0):t + 1, n]
            if len(tracks_coord) > 1:
                polyline_data.append((tracks_coord, color))

        # Blend circles overlay once
        overlay_np = np.array(overlay)
        alpha = overlay_np[:, :, 3:4] / 255.0
        frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)

        # Draw all polylines on a single overlay
        if polyline_data:
            polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
            for tracks_coord, color in polyline_data:
                _draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity)

            # Blend polylines overlay once
            polyline_np = np.array(polyline_overlay)
            alpha = polyline_np[:, :, 3:4] / 255.0
            frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)

        output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8)))

    return output_frames

class WanMoveVisualizeTracks(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanMoveVisualizeTracks",
            category="conditioning/video_models",
            inputs=[
                io.Image.Input("images"),
                io.Tracks.Input("tracks", optional=True),
                io.Int.Input("line_resolution", default=24, min=1, max=1024),
                io.Int.Input("circle_size", default=12, min=1, max=128),
                io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
                io.Int.Input("line_width", default=16, min=1, max=128),
            ],
            outputs=[
                io.Image.Output(),
            ],
        )

    @classmethod
    def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput:
        if tracks is None:
            return io.NodeOutput(images)

        track_path = tracks["track_path"].unsqueeze(0)
        track_visibility = tracks["track_visibility"].unsqueeze(0)
        images_in = images * 255.0
        if images_in.shape[0] != track_path.shape[1]:
            repeat_count = track_path.shape[1] // images.shape[0]
            images_in = images_in.repeat(repeat_count, 1, 1, 1)
        track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
        track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float()

        return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device()))

class WanMoveTracksFromCoords(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanMoveTracksFromCoords",
            category="conditioning/video_models",
            inputs=[
                io.String.Input("track_coords", force_input=True, default="[]", optional=True),
                io.Mask.Input("track_mask", optional=True),
            ],
            outputs=[
                io.Tracks.Output(),
                io.Int.Output(display_name="track_length"),
            ],
        )

    @classmethod
    def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
        device = comfy.model_management.intermediate_device()

        tracks_data = parse_json_tracks(track_coords)
        track_length = len(tracks_data[0])

        track_list = [
            [[track[frame]['x'], track[frame]['y']] for track in tracks_data]
            for frame in range(len(tracks_data[0]))
        ]
        tracks = torch.tensor(track_list, dtype=torch.float32, device=device)  # [frames, num_tracks, 2]

        num_tracks = tracks.shape[-2]
        if track_mask is None:
            track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
        else:
            track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)

        out_track_info = {}
        out_track_info["track_path"] = tracks
        out_track_info["track_visibility"] = track_visibility
        return io.NodeOutput(out_track_info, track_length)

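# Editorial note, not part of the committed file: judging from how execute()
# indexes the parsed structure (track[frame]['x'] / track[frame]['y']), the
# track_coords string is expected to decode to a JSON list of tracks, each a
# list of per-frame points, e.g. two tracks over three frames:
# [[{"x": 100, "y": 200}, {"x": 110, "y": 205}, {"x": 120, "y": 210}],
#  [{"x": 400, "y": 240}, {"x": 395, "y": 240}, {"x": 390, "y": 240}]]
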
class GenerateTracks(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="GenerateTracks",
            category="conditioning/video_models",
            inputs=[
                io.Int.Input("width", default=832, min=16, max=4096, step=16),
                io.Int.Input("height", default=480, min=16, max=4096, step=16),
                io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
                io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
                io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
                io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
                io.Int.Input("num_frames", default=81, min=1, max=1024),
                io.Int.Input("num_tracks", default=5, min=1, max=100),
                io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
                io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
                io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
                io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
                io.Combo.Input(
                    "interpolation",
                    options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
                    tooltip="Controls the timing/speed of movement along the path.",
                ),
                io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
            ],
            outputs=[
                io.Tracks.Output(),
                io.Int.Output(display_name="track_length"),
            ],
        )

    @classmethod
    def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
                track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
        device = comfy.model_management.intermediate_device()
        track_length = num_frames

        # normalized coordinates to pixel coordinates
        start_x_px = start_x * width
        start_y_px = start_y * height
        mid_x_px = mid_x * width
        mid_y_px = mid_y * height
        end_x_px = end_x * width
        end_y_px = end_y * height

        track_spread_px = track_spread * (width + height) / 2  # Use average of width/height for spread to keep it proportional

        t = torch.linspace(0, 1, num_frames, device=device)
        if interpolation == "constant":  # All points stay at start position
            interp_values = torch.zeros_like(t)
        elif interpolation == "linear":
            interp_values = t
        elif interpolation == "ease_in":
            interp_values = t ** 2
        elif interpolation == "ease_out":
            interp_values = 1 - (1 - t) ** 2
        elif interpolation == "ease_in_out":
            interp_values = t * t * (3 - 2 * t)

        if bezier:  # apply interpolation to t for timing control along the bezier path
            t_interp = interp_values
            one_minus_t = 1 - t_interp
            x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
            y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
            tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
            tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
        else:  # calculate base x and y positions for each frame (center track)
            x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
            y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
            # For non-bezier, tangent is constant (direction from start to end)
            tangent_x = torch.full_like(t, end_x_px - start_x_px)
            tangent_y = torch.full_like(t, end_y_px - start_y_px)

        track_list = []
        for frame_idx in range(num_frames):
            # Calculate perpendicular direction at this frame
            tx = tangent_x[frame_idx].item()
            ty = tangent_y[frame_idx].item()
            length = (tx ** 2 + ty ** 2) ** 0.5

            if length > 0:  # Perpendicular unit vector (rotate 90 degrees)
                perp_x = -ty / length
                perp_y = tx / length
            else:  # If tangent is zero, spread horizontally
                perp_x = 1.0
                perp_y = 0.0

            frame_tracks = []
            for track_idx in range(num_tracks):  # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
                offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
                track_x = x_positions[frame_idx].item() + perp_x * offset
                track_y = y_positions[frame_idx].item() + perp_y * offset
                frame_tracks.append([track_x, track_y])
            track_list.append(frame_tracks)

        tracks = torch.tensor(track_list, dtype=torch.float32, device=device)  # [frames, num_tracks, 2]

        if track_mask is None:
            track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
        else:
            track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)

        out_track_info = {}
        out_track_info["track_path"] = tracks
        out_track_info["track_visibility"] = track_visibility
        return io.NodeOutput(out_track_info, track_length)

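# Editorial note, not part of the committed file: a worked instance of the
# quadratic Bezier evaluated above, with assumed example values. At t = 0.5 the
# position is 0.25 * start + 0.5 * mid + 0.25 * end, so start_x_px = 0,
# mid_x_px = 416 and end_x_px = 832 give x = 416.
def _bezier_midpoint_sketch():
    t = torch.tensor(0.5)
    start_x_px, mid_x_px, end_x_px = 0.0, 416.0, 832.0
    x = (1 - t) ** 2 * start_x_px + 2 * (1 - t) * t * mid_x_px + t ** 2 * end_x_px
    assert float(x) == 416.0
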
class WanMoveConcatTrack(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanMoveConcatTrack",
            category="conditioning/video_models",
            inputs=[
                io.Tracks.Input("tracks_1"),
                io.Tracks.Input("tracks_2", optional=True),
            ],
            outputs=[
                io.Tracks.Output(),
            ],
        )

    @classmethod
    def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput:
        if tracks_2 is None:
            return io.NodeOutput(tracks_1)

        tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1)  # Concatenate along the track dimension
        mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1)

        out_track_info = {}
        out_track_info["track_path"] = tracks_out
        out_track_info["track_visibility"] = mask_out
        return io.NodeOutput(out_track_info)

class WanMoveTrackToVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanMoveTrackToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Tracks.Input("tracks", optional=True),
                io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.Image.Input("start_image"),
                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
        device = comfy.model_management.intermediate_device()
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
        if start_image is not None:
            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
            image[:start_image.shape[0]] = start_image

            concat_latent_image = vae.encode(image[:, :, :, :3])
            mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
            mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0

            if tracks is not None and strength > 0.0:
                tracks_path = tracks["track_path"][:length]  # [T, N, 2]
                num_tracks = tracks_path.shape[-2]

                track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device))

                track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
                track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
                concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)
            else:
                concat_latent_image_pos = concat_latent_image

            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask})
            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})

        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        out_latent = {}
        out_latent["samples"] = latent
        return io.NodeOutput(positive, negative, out_latent)

class WanMoveExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            WanMoveTrackToVideo,
            WanMoveTracksFromCoords,
            WanMoveConcatTrack,
            WanMoveVisualizeTracks,
            GenerateTracks,
        ]

async def comfy_entrypoint() -> WanMoveExtension:
    return WanMoveExtension()
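Taken together, the new module turns user-supplied trajectories into a latent-space hint: GenerateTracks or WanMoveTracksFromCoords produce a dict with "track_path" ([frames, num_tracks, 2]) and "track_visibility" ([frames, num_tracks]), and WanMoveTrackToVideo quantizes those tracks onto the latent grid with create_pos_embeddings (downsample ratios [4, 8, 8] for time/height/width) and then drags the first-frame VAE features along each track with replace_feature before attaching the result as concat_latent_image. A minimal sketch of that flow, assuming a ComfyUI checkout containing this branch so the module is importable; the shapes simply mirror the node defaults (832x480, 81 frames):

    import torch
    from comfy_extras.nodes_wanmove import create_pos_embeddings, replace_feature

    frames, n_tracks, height, width = 81, 5, 480, 832
    t = torch.linspace(0, 1, frames).unsqueeze(-1)
    # One straight diagonal track per the linear interpolation in GenerateTracks, copied n_tracks times.
    track_path = torch.cat([t * (width - 1), t * (height - 1)], dim=-1).unsqueeze(1).repeat(1, n_tracks, 1)
    track_visibility = torch.ones(frames, n_tracks, dtype=torch.bool)

    track_pos = create_pos_embeddings(track_path, track_visibility, [4, 8, 8], height, width, track_num=n_tracks)
    latent = torch.randn(1, 16, (frames - 1) // 4 + 1, height // 8, width // 8)  # stand-in for vae.encode(...)
    hinted = replace_feature(latent, track_pos.unsqueeze(0), strength=1.0)       # edited in place and returned
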
nodes.py (1 addition)
@ -2358,6 +2358,7 @@ async def init_builtin_extra_nodes():
         "nodes_logic.py",
         "nodes_nop.py",
         "nodes_kandinsky5.py",
+        "nodes_wanmove.py",
     ]

     import_failed = []
requirements.txt
@ -1,4 +1,4 @@
-comfyui-frontend-package==1.33.13
+comfyui-frontend-package==1.34.8
 comfyui-workflow-templates==0.7.54
 comfyui-embedded-docs==0.3.1
 torch