Compare commits

...

24 Commits

Author SHA1 Message Date
Homfen
3f7a4b5f62
Merge 6d9f2737d5 into 4e6a1b66a9 2026-01-25 04:04:15 +08:00
rattus
4e6a1b66a9
speed up and reduce VRAM of QWEN VAE and WAN (less so) (#12036)
* ops: introduce autopad for conv3d

This works around PyTorch's missing ability to causal-pad as part of the
kernel and avoids massive weight duplication for padding.

* wan-vae: rework causal padding

This currently uses F.pad which takes a full deep copy and is liable to
be the VRAM peak. Instead, kick spatial padding back to the op and
consolidate the temporal padding with the cat for the cache.

* wan-vae: implement zero pad fast path

The WAN VAE is also the QWEN VAE, where it is used single-image. These
convolutions are, however, zero-padded 3D convolutions, which means the
VAE is effectively just 2D through the last element of the conv weight in
the temporal dimension. Fast-path this to avoid adding zeros that
just evaporate in the convolution math but still cost computation.
2026-01-23 19:56:14 -05:00
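
A minimal, hedged sketch of the padding ideas in the commit above (illustrative names, not the actual ComfyUI ops code): spatial padding is left to the conv op itself, temporal causal padding is fused with the frame cache in a single cat instead of an F.pad deep copy, and a zero-pad fast path collapses the 3D conv to a 2D conv over the last temporal slice of the weight when only one frame is present.

```python
# Sketch only; CausalConv3dSketch and its arguments are assumptions for
# illustration, not the real implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv3dSketch(nn.Module):
    def __init__(self, c_in, c_out, k=3, zero_pad=False):
        super().__init__()
        # Spatial padding is delegated to the conv op; only temporal padding
        # is handled manually here.
        self.conv = nn.Conv3d(c_in, c_out, k, padding=(0, k // 2, k // 2))
        self.t_pad = k - 1
        self.zero_pad = zero_pad  # zero temporal padding (QWEN single-image case)

    def forward(self, x, cache=None):
        # Zero-pad fast path: with one frame and zero temporal padding, only
        # the last temporal slice of the kernel sees non-zero input, so the
        # 3D conv degenerates into a plain 2D conv.
        if self.zero_pad and cache is None and x.shape[2] == 1:
            y = F.conv2d(x[:, :, 0], self.conv.weight[:, :, -1], self.conv.bias,
                         padding=self.conv.padding[1:])
            return y.unsqueeze(2)

        if cache is None:
            # Causal padding: build the leading frames once (replicate or zero)...
            pad = torch.zeros_like(x[:, :, :1]) if self.zero_pad else x[:, :, :1]
            cache = pad.repeat(1, 1, self.t_pad, 1, 1)
        # ...and fuse them with the rolling cache in a single cat, instead of
        # an F.pad call that deep-copies the whole tensor.
        x = torch.cat([cache, x], dim=2)
        return self.conv(x)
```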
comfyanonymous
9cf299a9f9
Make regular empty latent node work properly on flux 2 variants. (#12050) 2026-01-23 19:50:48 -05:00
ComfyUI Wiki
e89b22993a
Support ModelScope-Trainer/DiffSynth LoRA format for Flux.2 Klein models (#12042)
2026-01-23 15:27:49 -05:00
Jukka Seppänen
55bd606e92
LTX2: Refactor forward function for better VRAM efficiency and fix spatial inpainting (#12046)
* Disable timestep embed compression when inpainting

Spatial inpainting is not compatible with the compression

* Reduce crossattn peak VRAM

* LTX2: Refactor forward function for better VRAM efficiency
2026-01-23 15:26:38 -05:00
Christian Byrne
79cdbc81cb
feat: Improve ResizeImageMaskNode UX with tooltips and search aliases (#12040)
- Add search_aliases for discoverability: resize, scale, dimensions, etc.
- Add node description for hover tooltip
- Add tooltips to all inputs explaining their behavior
- Reorder options: most common (scale dimensions) first, most technical (scale to multiple) last

Addresses user feedback that 'resize' search returned nothing useful and
options like 'match size' and 'scale to multiple' were not self-explanatory.
2026-01-22 22:04:27 -08:00
comfyanonymous
f443b9f2ca
Revert "feat: Improve ResizeImageMaskNode UX with tooltips and search aliases…" (#12038)
This reverts commit 4e3038114a.
2026-01-22 23:02:37 -05:00
Christian Byrne
4e3038114a
feat: Improve ResizeImageMaskNode UX with tooltips and search aliases (#12013)
- Add search_aliases for discoverability: resize, scale, dimensions, etc.
- Add node description for hover tooltip
- Add tooltips to all inputs explaining their behavior
- Reorder options: most common (scale dimensions) first, most technical (scale to multiple) last

Addresses user feedback that 'resize' search returned nothing useful and
options like 'match size' and 'scale to multiple' were not self-explanatory.
2026-01-22 18:46:55 -08:00
Christian Byrne
bbb8864778
add search aliases to all nodes (#12035)
* feat: Add search_aliases field to node schema

Adds `search_aliases` field to improve node discoverability. Users can define alternative search terms for nodes (e.g., "text concat" → StringConcatenate).

Changes:
- Add `search_aliases: list[str]` to V3 Schema
- Add `SEARCH_ALIASES` support for V1 nodes
- Include field in `/object_info` response
- Add aliases to high-priority core nodes

V1 usage:
```python
class MyNode:
    SEARCH_ALIASES = ["alt name", "synonym"]
```

V3 usage:
```python
io.Schema(
    node_id="MyNode",
    search_aliases=["alt name", "synonym"],
    ...
)
```

## Related PRs
- Frontend: Comfy-Org/ComfyUI_frontend#XXXX (draft - merge after this)
- Docs: Comfy-Org/docs#XXXX (draft - merge after stable)

* Propagate search_aliases through V3 Schema.get_v1_info to NodeInfoV1

* feat: add SEARCH_ALIASES for core nodes (#12016)

Add search aliases to 22 core nodes in nodes.py to improve node discoverability:
- Checkpoint/model loaders: CheckpointLoader, DiffusersLoader
- Conditioning nodes: ConditioningAverage, ConditioningSetArea, ConditioningSetMask, ConditioningZeroOut
- Style nodes: StyleModelApply
- Image nodes: LoadImageMask, LoadImageOutput, ImageBatch, ImageInvert, ImagePadForOutpaint
- Latent nodes: LoadLatent, SaveLatent, LatentBlend, LatentComposite, LatentCrop, LatentFlip, LatentFromBatch, LatentUpscale, LatentUpscaleBy, RepeatLatentBatch

* feat: add SEARCH_ALIASES for image, mask, and string nodes (#12017)

Add search aliases to nodes in comfy_extras for better discoverability:
- nodes_mask.py: mask manipulation nodes
- nodes_images.py: image processing nodes
- nodes_post_processing.py: post-processing effect nodes
- nodes_string.py: string manipulation nodes
- nodes_compositing.py: compositing nodes
- nodes_morphology.py: morphological operation nodes
- nodes_latent.py: latent space nodes

Uses search_aliases parameter in io.Schema() for v3 nodes.

* feat: add SEARCH_ALIASES for audio and video nodes (#12018)

Add search aliases to audio and video nodes for better discoverability:
- nodes_audio.py: audio loading, saving, and processing nodes
- nodes_video.py: video loading and processing nodes
- nodes_wan.py: WAN model nodes

Uses search_aliases parameter in io.Schema() for v3 nodes.

* feat: add SEARCH_ALIASES for model and misc nodes (#12019)

Add search aliases to model-related and miscellaneous nodes:
- Model nodes: nodes_model_merging.py, nodes_model_advanced.py, nodes_lora_extract.py
- Sampler nodes: nodes_custom_sampler.py, nodes_align_your_steps.py
- Control nodes: nodes_controlnet.py, nodes_attention_multiply.py, nodes_hooks.py
- Training nodes: nodes_train.py, nodes_dataset.py
- Utility nodes: nodes_logic.py, nodes_canny.py, nodes_differential_diffusion.py
- Architecture-specific: nodes_sd3.py, nodes_pixart.py, nodes_lumina2.py, nodes_kandinsky5.py, nodes_hidream.py, nodes_fresca.py, nodes_hunyuan3d.py
- Media nodes: nodes_load_3d.py, nodes_webcam.py, nodes_preview_any.py, nodes_wanmove.py

Uses search_aliases parameter in io.Schema() for v3 nodes, SEARCH_ALIASES class attribute for legacy nodes.
2026-01-22 18:36:58 -08:00
Omri Marom
d7f3241bf6
qwen_image: propagate attention mask. (#11966) 2026-01-22 20:02:31 -05:00
comfyanonymous
09a2e67151
Support loading flux 2 klein checkpoints saved with SaveCheckpoint. (#12033) 2026-01-22 18:20:48 -05:00
rattus
0fd1b78736
Reduce LTX2 VAE VRAM consumption (#12028)
* causal_video_ae: Remove attention ResNet

This attention_head_dim argument does not exist on this constructor, so
this is dead code. Remove it, as generic attention in the mid VAE conflicts
with the temporal roll.

* ltx-vae: consolidate causal/non-causal code paths

* ltx-vae: add cache rolling adder

* ltx-vae: use cached adder for resnet

* ltx-vae: Implement rolling VAE

Implement a temporal rolling VAE for the LTX2 VAE.

Usually, when doing temporal rolling VAEs, you can just chunk on time, relying
on causality, and cache behind you as you go. The LTX VAE is, however,
non-causal.

So go whole hog and implement per-layer run-ahead and backpressure between
the decoder layers, using recursive state between the layers.

Operations are amended with temporal_cache_state{}, which they can use to
hold any state they need for partial execution. Convolutions cache up to
N-1 frames of their input behind them, and skip connections need to cache the
mismatch between convolution input and output that happens due to missing
future (non-causal) input.

Each call to run_up() processes a layer across a range of input that
may or may not be complete. It goes depth first, processing as much as
possible to try to digest frames to the final output ASAP. If layers run
out of input due to convolution losses, they simply return without action,
effectively applying back-pressure to the earlier layers. As the earlier
layers do more work and call deeper, the partial states are reconciled
and output continues to be digested depth first as much as possible.

Chunking is done using a size quota rather than a fixed frame length; any
layer can initiate chunking, and multiple layers can chunk at different
granularities. This removes the old limitation of always having to process
one latent frame in its entirety and having to hold 8 full decoded frames as
the VRAM peak.
2026-01-22 16:54:18 -05:00
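
A hedged, self-contained sketch of the depth-first run-ahead / back-pressure scheme described in the commit above (RollingLayer, the lookahead values, and the driver are assumptions for illustration, not the real decoder): each layer keeps a temporal cache and returns an empty chunk when it does not yet have enough frames, and re-chunking is driven by a byte quota so any layer can initiate chunking.

```python
# Toy sketch only; the real decoder runs convolutions and skip connections,
# this just models frame accounting and back-pressure.
import torch

MAX_CHUNK_BYTES = 128 * 1024 ** 2

class RollingLayer:
    """Toy non-causal layer that needs `lookahead` future frames per output frame."""
    def __init__(self, lookahead):
        self.lookahead = lookahead
        self.cache = None  # frames seen but not yet fully processed

    def __call__(self, x, ended):
        x = x if self.cache is None else torch.cat([self.cache, x], dim=2)
        usable = x.shape[2] if ended else x.shape[2] - self.lookahead
        if usable <= 0:
            self.cache = x            # not enough input yet: apply back-pressure
            return x[:, :, :0]
        self.cache = None if ended else x[:, :, usable:]
        return x[:, :, :usable]       # stand-in for the real computation

def run_up(layers, idx, sample, ended, output):
    if idx >= len(layers):
        if sample.shape[2] > 0:
            output.append(sample)     # frames reach the final output ASAP
        return
    sample = layers[idx](sample, ended)
    if sample.shape[2] == 0:
        return                        # this layer applied back-pressure
    # Re-chunk by a size quota, so any layer can initiate chunking.
    total_bytes = sample.numel() * sample.element_size()
    n_chunks = max(1, (total_bytes + MAX_CHUNK_BYTES - 1) // MAX_CHUNK_BYTES)
    chunks = torch.chunk(sample, n_chunks, dim=2)
    for i, chunk in enumerate(chunks):
        run_up(layers, idx + 1, chunk, ended and i == len(chunks) - 1, output)

layers = [RollingLayer(2), RollingLayer(1)]
latent = torch.randn(1, 4, 16, 8, 8)   # [B, C, T, H, W]
out = []
for t, chunk in enumerate(torch.chunk(latent, 4, dim=2)):
    run_up(layers, 0, chunk, ended=(t == 3), output=out)
decoded = torch.cat(out, dim=2)         # all 16 frames, despite partial chunks
```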
Terry Jia
8490eedadf
add ply & 3dgs format in 3d node (#11474)
2026-01-22 09:46:56 -08:00
Alexander Piskun
72f6be1690
chore(api-nodes): rename BriaImage and OpenAIGImage nodes (#12022)
2026-01-21 23:42:04 -08:00
Jukka Seppänen
16b9aabd52
Support Multi/InfiniteTalk (#10179)
* re-init

* Update model_multitalk.py

* whitespace...

* Update model_multitalk.py

* remove print

* this is redundant

* remove import

* Restore preview functionality

* Move block_idx to transformer_options

* Remove LoopingSamplerCustomAdvanced

* Remove looping functionality, keep extension functionality

* Update model_multitalk.py

* Handle ref_attn_mask with separate patch to avoid having to always return q and k from self_attn

* Chunk attention map calculation for multiple speakers to reduce peak VRAM usage

* Update model_multitalk.py

* Add ModelPatch type back

* Fix for latest upstream

* Use DynamicCombo for cleaner node

Basically just so that single_speaker mode hides mask inputs and 2nd audio input

* Update nodes_wan.py
2026-01-21 23:09:48 -05:00
Jukka Seppänen
245f6139b6
More targeted embedding_connector loading for LTX2 text encoder (#11992)
Reduces errors
2026-01-21 23:05:06 -05:00
Jukka Seppänen
3365ad18a5
Support LTX2 tiny vae (taeltx_2) (#11929) 2026-01-21 23:03:51 -05:00
Jedrzej Kosinski
f09904720d
Fix for edge case of EasyCache when conditionings change during a sampling run (like with timestep scheduling) (#12020) 2026-01-21 23:01:35 -05:00
comfyanonymous
abe2ec26a6
Support the Anima model. (#12012) 2026-01-21 19:44:28 -05:00
Christian Byrne
bdeac8897e
feat: Add search_aliases field to node schema (#12010)
* feat: Add search_aliases field to node schema

Adds `search_aliases` field to improve node discoverability. Users can define alternative search terms for nodes (e.g., "text concat" → StringConcatenate).

Changes:
- Add `search_aliases: list[str]` to V3 Schema
- Add `SEARCH_ALIASES` support for V1 nodes
- Include field in `/object_info` response
- Add aliases to high-priority core nodes

V1 usage:
```python
class MyNode:
    SEARCH_ALIASES = ["alt name", "synonym"]
```

V3 usage:
```python
io.Schema(
    node_id="MyNode",
    search_aliases=["alt name", "synonym"],
    ...
)
```

## Related PRs
- Frontend: Comfy-Org/ComfyUI_frontend#XXXX (draft - merge after this)
- Docs: Comfy-Org/docs#XXXX (draft - merge after stable)

* Propagate search_aliases through V3 Schema.get_v1_info to NodeInfoV1
2026-01-21 15:36:02 -08:00
Alexander Piskun
451af70154
fix(api-nodes-Vidu): allow passing up to 7 subjects in Vidu Reference node (#12002)
2026-01-21 04:03:45 -08:00
Markury
0fc15700be
Add LyCoris LoKr MLP layer support for Flux2 (#11997)
2026-01-20 23:18:33 -05:00
comfyanonymous
e755268e7b
Config for Qwen 3 0.6B model. (#11998) 2026-01-20 23:08:31 -05:00
homfen
6d9f2737d5 fix(SaveImage Node) 2025-08-12 12:20:49 +08:00
67 changed files with 1706 additions and 303 deletions

View File

@ -8,6 +8,7 @@ class LatentFormat:
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
taesd_decoder_name = None
spacial_downscale_ratio = 8
def process_in(self, latent):
return latent * self.scale_factor
@ -181,6 +182,7 @@ class Flux(SD3):
class Flux2(LatentFormat):
latent_channels = 128
spacial_downscale_ratio = 16
def __init__(self):
self.latent_rgb_factors =[
@ -749,6 +751,7 @@ class ACEAudio(LatentFormat):
class ChromaRadiance(LatentFormat):
latent_channels = 3
spacial_downscale_ratio = 1
def __init__(self):
self.latent_rgb_factors = [

202
comfy/ldm/anima/model.py Normal file
View File

@ -0,0 +1,202 @@
from comfy.ldm.cosmos.predict2 import MiniTrainDIT
import torch
from torch import nn
import torch.nn.functional as F
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed
class RotaryEmbedding(nn.Module):
def __init__(self, head_dim):
super().__init__()
self.rope_theta = 10000
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Attention(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
super().__init__()
inner_dim = head_dim * n_heads
self.n_heads = n_heads
self.head_dim = head_dim
self.query_dim = query_dim
self.context_dim = context_dim
self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)
self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)
def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
context = x if context is None else context
input_shape = x.shape[:-1]
q_shape = (*input_shape, self.n_heads, self.head_dim)
context_shape = context.shape[:-1]
kv_shape = (*context_shape, self.n_heads, self.head_dim)
query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
if position_embeddings is not None:
assert position_embeddings_context is not None
cos, sin = position_embeddings
query_states = apply_rotary_pos_emb(query_states, cos, sin)
cos, sin = position_embeddings_context
key_states = apply_rotary_pos_emb(key_states, cos, sin)
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
def init_weights(self):
torch.nn.init.zeros_(self.o_proj.weight)
class TransformerBlock(nn.Module):
def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
super().__init__()
self.use_self_attn = use_self_attn
if self.use_self_attn:
self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.self_attn = Attention(
query_dim=model_dim,
context_dim=model_dim,
n_heads=num_heads,
head_dim=model_dim//num_heads,
device=device,
dtype=dtype,
operations=operations,
)
self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.cross_attn = Attention(
query_dim=model_dim,
context_dim=source_dim,
n_heads=num_heads,
head_dim=model_dim//num_heads,
device=device,
dtype=dtype,
operations=operations,
)
self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
self.mlp = nn.Sequential(
operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
nn.GELU(),
operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype)
)
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
if self.use_self_attn:
normed = self.norm_self_attn(x)
attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
x = x + attn_out
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
return x
def init_weights(self):
torch.nn.init.zeros_(self.mlp[2].weight)
self.cross_attn.init_weights()
class LLMAdapter(nn.Module):
def __init__(
self,
source_dim=1024,
target_dim=1024,
model_dim=1024,
num_layers=6,
num_heads=16,
use_self_attn=True,
layer_norm=False,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
if model_dim != target_dim:
self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
else:
self.in_proj = nn.Identity()
self.rotary_emb = RotaryEmbedding(model_dim//num_heads)
self.blocks = nn.ModuleList([
TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
])
self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)
def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
if target_attention_mask is not None:
target_attention_mask = target_attention_mask.to(torch.bool)
if target_attention_mask.ndim == 2:
target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
if source_attention_mask is not None:
source_attention_mask = source_attention_mask.to(torch.bool)
if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids)
position_embeddings_context = self.rotary_emb(x, position_ids_context)
for block in self.blocks:
x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
return self.norm(self.out_proj(x))
class Anima(MiniTrainDIT):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
def preprocess_text_embeds(self, text_embeds, text_ids):
if text_ids is not None:
return self.llm_adapter(text_embeds, text_ids)
else:
return text_embeds

View File

@ -18,12 +18,12 @@ class CompressedTimestep:
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
"""
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
"""
self.batch_size, num_tokens, self.feature_dim = tensor.shape
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame
@ -215,22 +215,9 @@ class BasicAVTransformerBlock(nn.Module):
return (*scale_shift_ada_values, *gate_ada_values)
def forward(
self,
x: Tuple[torch.Tensor, torch.Tensor],
v_context=None,
a_context=None,
attention_mask=None,
v_timestep=None,
a_timestep=None,
v_pe=None,
a_pe=None,
v_cross_pe=None,
a_cross_pe=None,
v_cross_scale_shift_timestep=None,
a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None,
a_cross_gate_timestep=None,
transformer_options=None,
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
@ -240,144 +227,102 @@ class BasicAVTransformerBlock(nn.Module):
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
# video
if run_vx:
vshift_msa, vscale_msa, vgate_msa = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
)
# video self-attention
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
vx += self.attn2(
comfy.ldm.common_dit.rms_norm(vx),
context=v_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del vshift_msa, vscale_msa, vgate_msa
del vshift_msa, vscale_msa
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
del norm_vx
# video cross-attention
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
# audio
if run_ax:
ashift_msa, ascale_msa, agate_msa = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
)
# audio self-attention
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
ax += (
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
* agate_msa
)
ax += self.audio_attn2(
comfy.ldm.common_dit.rms_norm(ax),
context=a_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del ashift_msa, ascale_msa
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
del norm_ax
# audio cross-attention
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
del ashift_msa, ascale_msa, agate_msa
# Audio - Video cross attention.
# video - audio cross attention.
if run_a2v or run_v2a:
# norm3
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
(
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
a_cross_scale_shift_timestep,
a_cross_gate_timestep,
)
(
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
v_cross_scale_shift_timestep,
v_cross_gate_timestep,
)
# audio to video cross attention
if run_a2v:
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
+ shift_ca_video_hidden_states_a2v
)
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
+ shift_ca_audio_hidden_states_a2v
)
vx += (
self.audio_to_video_attn(
vx_scaled,
context=ax_scaled,
pe=v_cross_pe,
k_pe=a_cross_pe,
transformer_options=transformer_options,
)
* gate_out_a2v
)
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
del gate_out_a2v
del scale_ca_video_hidden_states_a2v,\
shift_ca_video_hidden_states_a2v,\
scale_ca_audio_hidden_states_a2v,\
shift_ca_audio_hidden_states_a2v,\
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
del vx_scaled, ax_scaled
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
vx.addcmul_(a2v_out, gate_out_a2v)
del gate_out_a2v, a2v_out
# video to audio cross attention
if run_v2a:
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
+ shift_ca_audio_hidden_states_v2a
)
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
+ shift_ca_video_hidden_states_v2a
)
ax += (
self.video_to_audio_attn(
ax_scaled,
context=vx_scaled,
pe=a_cross_pe,
k_pe=v_cross_pe,
transformer_options=transformer_options,
)
* gate_out_v2a
)
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
del gate_out_v2a
del scale_ca_video_hidden_states_v2a,\
shift_ca_video_hidden_states_v2a,\
scale_ca_audio_hidden_states_v2a,\
shift_ca_audio_hidden_states_v2a
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
del ax_scaled, vx_scaled
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
ax.addcmul_(v2a_out, gate_out_v2a)
del gate_out_v2a, v2a_out
del vx_norm3, ax_norm3
# video feedforward
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
)
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
vx += self.ff(vx_scaled) * vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp
del vshift_mlp, vscale_mlp
ff_out = self.ff(vx_scaled)
del vx_scaled
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
vx.addcmul_(ff_out, vgate_mlp)
del vgate_mlp, ff_out
# audio feedforward
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
)
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
ax += self.audio_ff(ax_scaled) * agate_mlp
del ashift_mlp, ascale_mlp
del ashift_mlp, ascale_mlp, agate_mlp
ff_out = self.audio_ff(ax_scaled)
del ax_scaled
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
ax.addcmul_(ff_out, agate_mlp)
del agate_mlp, ff_out
return vx, ax
@ -589,9 +534,20 @@ class LTXAVModel(LTXVModel):
audio_length = kwargs.get("audio_length", 0)
# Separate audio and video latents
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
has_spatial_mask = False
if denoise_mask is not None:
# check if any frame has spatial variation (inpainting)
for frame_idx in range(denoise_mask.shape[2]):
frame_mask = denoise_mask[0, 0, frame_idx]
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
has_spatial_mask = True
break
[vx, v_pixel_coords, additional_args] = super()._process_input(
vx, keyframe_idxs, denoise_mask, **kwargs
)
additional_args["has_spatial_mask"] = has_spatial_mask
ax, a_latent_coords = self.a_patchifier.patchify(ax)
ax = self.audio_patchify_proj(ax)
@ -618,8 +574,9 @@ class LTXAVModel(LTXVModel):
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None
if orig_shape is not None and len(orig_shape) == 5:
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]
@ -662,10 +619,11 @@ class LTXAVModel(LTXVModel):
)
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
]

View File

@ -1,11 +1,11 @@
from typing import Tuple, Union
import threading
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
def __init__(
self,
@ -42,23 +42,34 @@ class CausalConv3d(nn.Module):
padding_mode=spatial_padding_mode,
groups=groups,
)
self.temporal_cache_state={}
def forward(self, x, causal: bool = True):
if causal:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
last_frame_pad = x[:, :, -1:, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
return x
tid = threading.get_ident()
cached, is_end = self.temporal_cache_state.get(tid, (None, False))
if cached is None:
padding_length = self.time_kernel_size - 1
if not causal:
padding_length = padding_length // 2
if x.shape[2] == 0:
return x
cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
pieces = [ cached, x ]
if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
needs_caching = not is_end
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
@property
def weight(self):

View File

@ -1,4 +1,5 @@
from __future__ import annotations
import threading
import torch
from torch import nn
from functools import partial
@ -6,12 +7,35 @@ import math
from einops import rearrange
from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
if isinstance(m, CausalConv3d):
current = m.temporal_cache_state.get(tid, (None, False))
m.temporal_cache_state[tid] = (current[0], True)
def split2(tensor, split_point, dim=2):
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
def add_exchange_cache(dest, cache_in, new_input, dim=2):
if dest is not None:
if cache_in is not None:
cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
lead_in_dest.add_(lead_in_source)
body, new_input = split2(new_input, dest.shape[dim], dim)
dest.add_(body)
return torch_cat_if_needed([cache_in, new_input], dim=dim)
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@ -205,7 +229,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
@ -254,6 +278,22 @@ class Encoder(nn.Module):
return sample
def forward(self, *args, **kwargs):
# No encoder support, so just flag the end so it doesn't use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
tid = threading.get_ident()
for _, module in self.named_modules():
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
class Decoder(nn.Module):
r"""
@ -341,18 +381,6 @@ class Decoder(nn.Module):
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
@ -428,8 +456,9 @@ class Decoder(nn.Module):
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward(
def forward_orig(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
@ -437,6 +466,7 @@ class Decoder(nn.Module):
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
mark_conv3d_ended(self.conv_in)
sample = self.conv_in(sample, causal=self.causal)
checkpoint_fn = (
@ -445,24 +475,12 @@ class Decoder(nn.Module):
else lambda x: x
)
scaled_timestep = None
timestep_shift_scale = None
if self.timestep_conditioning:
assert (
timestep is not None
), "should pass timestep with timestep_conditioning=True"
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
for up_block in self.up_blocks:
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
if self.timestep_conditioning:
embedded_timestep = self.last_time_embedder(
timestep=scaled_timestep.flatten(),
resolution=None,
@ -483,16 +501,62 @@ class Decoder(nn.Module):
embedded_timestep.shape[-2],
embedded_timestep.shape[-1],
)
shift, scale = ada_values.unbind(dim=1)
sample = sample * (1 + scale) + shift
timestep_shift_scale = ada_values.unbind(dim=1)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
output = []
def run_up(idx, sample, ended):
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
return
up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
def forward(self, *args, **kwargs):
try:
return self.forward_orig(*args, **kwargs)
finally:
for _, module in self.named_modules():
#ComfyUI doesn't thread this kind of stuff today, but just in case
#we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
class UNetMidBlock3D(nn.Module):
"""
@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module):
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
tid = threading.get_ident()
cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
y = self.conv(x, causal=causal)
y = rearrange(
y,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
y = y[:, :, 1:, :, :]
drop_first_conv = False
if self.residual:
# Reshape and duplicate the input to match the output shape
x_in = rearrange(
@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module):
)
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
if self.stride[0] == 2:
if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
x_in = x_in[:, :, 1:, :, :]
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2:
x = x[:, :, 1:, :, :]
if self.residual:
x = x + x_in
return x
drop_first_res = False
if y.shape[2] == 0:
y = None
cached = add_exchange_cache(y, cached, x_in, dim=2)
self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)
else:
self.temporal_cache_state[tid] = (None, drop_first_conv, False)
return y
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module):
torch.randn(4, in_channels) / in_channels**0.5
)
self.temporal_cache_state={}
def _feed_spatial_noise(
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
) -> torch.FloatTensor:
@ -880,9 +959,12 @@ class ResnetBlock3D(nn.Module):
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
tid = threading.get_ident()
cached = self.temporal_cache_state.get(tid, None)
cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
self.temporal_cache_state[tid] = cached
return output_tensor
return hidden_states
def patchify(x, patch_size_hw, patch_size_t=1):

View File

@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae():
import xformers.ops
def torch_cat_if_needed(xl, dim):
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
if len(xl) > 1:
return torch.cat(xl, dim)
else:
elif len(xl) == 1:
return xl[0]
else:
return None
def get_timestep_embedding(timesteps, embedding_dim):
"""

View File

@ -170,8 +170,14 @@ class Attention(nn.Module):
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else:
attn_mask = None
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
attn_mask, transformer_options=transformer_options,
skip_reshape=True)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
@ -430,6 +436,9 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]

View File

@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module):
x(Tensor): Shape [B, L, num_heads, C / num_heads]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
patches = transformer_options.get("patches", {})
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
def qkv_fn_q(x):
@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module):
transformer_options=transformer_options,
)
if "attn1_patch" in patches:
for p in patches["attn1_patch"]:
x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
x = self.o(x)
return x
@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module):
"""
# assert e.dtype == torch.float32
patches = transformer_options.get("patches", {})
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module):
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if "attn2_patch" in patches:
for p in patches["attn2_patch"]:
x = p({"x": x, "transformer_options": transformer_options})
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@ -488,7 +501,7 @@ class WanModel(torch.nn.Module):
self.blocks = nn.ModuleList([
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
for i in range(num_layers)
])
# head
@ -541,6 +554,7 @@ class WanModel(torch.nn.Module):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
# time embeddings
@ -738,6 +752,7 @@ class VaceWanModel(WanModel):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
# time embeddings

View File

@ -0,0 +1,500 @@
import torch
from einops import rearrange, repeat
import comfy
from comfy.ldm.modules.attention import optimized_attention
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
scale = 1.0 / visual_q.shape[-1] ** 0.5
visual_q = visual_q.transpose(1, 2) * scale
B, H, x_seqlens, K = visual_q.shape
x_ref_attn_maps = []
for class_idx, ref_target_mask in enumerate(ref_target_masks):
ref_target_mask = ref_target_mask.view(1, 1, 1, -1)
x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
for i in range(0, x_seqlens, chunk_size):
end_i = min(i + chunk_size, x_seqlens)
attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
# Apply softmax
attn_max = attn_chunk.max(dim=-1, keepdim=True).values
attn_chunk = (attn_chunk - attn_max).exp()
attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
attn_chunk = attn_chunk / (attn_sum + 1e-8)
# Apply mask and sum
masked_attn = attn_chunk * ref_target_mask
x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
del attn_chunk, masked_attn
# Average across heads
x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
x_ref_attn_maps.append(x_ref_attnmap)
del visual_q, ref_k
return torch.cat(x_ref_attn_maps, dim=0)
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
"""Args:
query (torch.tensor): B M H K
key (torch.tensor): B M H K
shape (tuple): (N_t, N_h, N_w)
ref_target_masks: [B, N_h * N_w]
"""
N_t, N_h, N_w = shape
x_seqlens = N_h * N_w
ref_k = ref_k[:, :x_seqlens]
_, seq_lens, heads, _ = visual_q.shape
class_num, _ = ref_target_masks.shape
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q)
split_chunk = heads // split_num
for i in range(split_num):
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :],
ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :],
ref_target_masks
)
x_ref_attn_maps += x_ref_attn_maps_perhead
return x_ref_attn_maps / split_num
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
source_min, source_max = source_range
new_min, new_max = target_range
normalized = (column - source_min) / (source_max - source_min + epsilon)
scaled = normalized * (new_max - new_min) + new_min
return scaled
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
def get_audio_embeds(encoded_audio, audio_start, audio_end):
audio_embs = []
human_num = len(encoded_audio)
audio_frames = encoded_audio[0].shape[0]
indices = (torch.arange(4 + 1) - 2) * 1
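# offsets [-2, -1, 0, 1, 2]: a five-frame audio window centered on each video frame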
for human_idx in range(human_num):
if audio_end > audio_frames: # if there isn't enough audio for the current window, pad with the first audio frame, as that's most likely silence
pad_len = audio_end - audio_frames
pad_shape = list(encoded_audio[human_idx].shape)
pad_shape[0] = pad_len
pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1)))
encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0)
else:
encoded_audio_in = encoded_audio[human_idx]
center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0)
center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1)
audio_emb = encoded_audio_in[center_indices].unsqueeze(0)
audio_embs.append(audio_emb)
return torch.cat(audio_embs, dim=0)
def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end):
audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end)
first_frame_audio_emb_s = audio_embs[:, :1, ...]
latter_frame_audio_emb = audio_embs[:, 1:, ...]
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4)
middle_index = audio_proj.seq_len // 2
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
audio_emb = torch.cat(audio_emb.split(1), dim=2)
return audio_emb
class RotaryPositionalEmbedding1D(torch.nn.Module):
def __init__(self,
head_dim,
):
super().__init__()
self.head_dim = head_dim
self.base = 10000
def precompute_freqs_cis_1d(self, pos_indices):
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
freqs = freqs.to(pos_indices.device)
freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
return freqs
def forward(self, x, pos_indices):
freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
x_ = x.float()
freqs_cis = freqs_cis.float().to(x.device)
cos, sin = freqs_cis.cos(), freqs_cis.sin()
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
x_ = (x_ * cos) + (rotate_half(x_) * sin)
return x_.type_as(x)
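# Editor's sketch (toy shapes assumed): the rotation applied above is norm-preserving per token,
# so remapping tokens to different speaker positions never changes attention magnitudes.
import torch
rope = RotaryPositionalEmbedding1D(head_dim=64)
q = torch.randn(1, 8, 10, 64)                    # (B, H, seq, head_dim), the layout used below
pos = torch.arange(10).float()                   # arbitrary per-token positions, e.g. speaker buckets
q_rot = rope(q, pos)
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)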
class SingleStreamAttention(torch.nn.Module):
def __init__(
self,
dim: int,
encoder_hidden_states_dim: int,
num_heads: int,
qkv_bias: bool,
device=None, dtype=None, operations=None
) -> None:
super().__init__()
self.dim = dim
self.encoder_hidden_states_dim = encoder_hidden_states_dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
N_t, N_h, N_w = shape
expected_tokens = N_t * N_h * N_w
actual_tokens = x.shape[1]
x_extra = None
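# if the token count doesn't match the grid, split off the trailing frame's tokens here and re-append zeros for them after attention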
if actual_tokens != expected_tokens:
x_extra = x[:, -N_h * N_w:, :]
x = x[:, :-N_h * N_w, :]
N_t = N_t - 1
B = x.shape[0]
S = N_h * N_w
x = x.view(B * N_t, S, self.dim)
# get q for hidden_state
q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)
# get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim)
kv = self.kv_linear(encoder_hidden_states)
encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)
#print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128])
x = optimized_attention(
q.transpose(1, 2),
encoder_k.transpose(1, 2),
encoder_v.transpose(1, 2),
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
# linear transform
x = self.proj(x.reshape(B * N_t, S, self.dim))
x = x.view(B, N_t * S, self.dim)
if x_extra is not None:
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
return x
class SingleStreamMultiAttention(SingleStreamAttention):
def __init__(
self,
dim: int,
encoder_hidden_states_dim: int,
num_heads: int,
qkv_bias: bool,
class_range: int = 24,
class_interval: int = 4,
device=None, dtype=None, operations=None
) -> None:
super().__init__(
dim=dim,
encoder_hidden_states_dim=encoder_hidden_states_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
device=device,
dtype=dtype,
operations=operations
)
# Rotary-embedding layout parameters
self.class_interval = class_interval
self.class_range = class_range
self.max_humans = self.class_range // self.class_interval
# Constant bucket used for background tokens
self.rope_bak = int(self.class_range // 2)
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
def forward(
self,
x: torch.Tensor,
encoder_hidden_states: torch.Tensor,
shape=None,
x_ref_attn_map=None
) -> torch.Tensor:
encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device)
human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1
# Single-speaker fall-through
if human_num <= 1:
return super().forward(x, encoder_hidden_states, shape)
N_t, N_h, N_w = shape
x_extra = None
if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
x_extra = x[:, -N_h * N_w:, :]
x = x[:, :-N_h * N_w, :]
N_t = N_t - 1
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
# Query projection
B, N, C = x.shape
q = self.q_linear(x)
q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
# Use `class_range` logic for 2 speakers
rope_h1 = (0, self.class_interval)
rope_h2 = (self.class_range - self.class_interval, self.class_range)
rope_bak = int(self.class_range // 2)
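# with the defaults class_interval=4 and class_range=24: speaker 1 tokens map into [0, 4), speaker 2 into [20, 24), and background tokens sit at 12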
# Normalize and scale attention maps for each speaker
max_values = x_ref_attn_map.max(1).values[:, None, None]
min_values = x_ref_attn_map.min(1).values[:, None, None]
max_min_values = torch.cat([max_values, min_values], dim=2)
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)
# Token-wise speaker dominance
max_indices = x_ref_attn_map.argmax(dim=0)
normalized_map = torch.stack([human1, human2, back], dim=1)
normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]
# Apply rotary to Q
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
q = self.rope_1d(q, normalized_pos)
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
# Keys / Values
_, N_a, _ = encoder_hidden_states.shape
encoder_kv = self.kv_linear(encoder_hidden_states)
encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
encoder_k, encoder_v = encoder_kv.unbind(0)
# Rotary for keys: assign the centre of each speaker bucket to its context tokens
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2
encoder_pos = torch.cat([per_frame] * N_t, dim=0)
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
encoder_k = self.rope_1d(encoder_k, encoder_pos)
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
# Final attention
q = rearrange(q, "B H M K -> B M H K")
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
x = optimized_attention(
q.transpose(1, 2),
encoder_k.transpose(1, 2),
encoder_v.transpose(1, 2),
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
# Linear projection
x = x.reshape(B, N, C)
x = self.proj(x)
# Restore original layout
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
if x_extra is not None:
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
return x
class MultiTalkAudioProjModel(torch.nn.Module):
def __init__(
self,
seq_len: int = 5,
seq_len_vf: int = 12,
blocks: int = 12,
channels: int = 768,
intermediate_dim: int = 512,
out_dim: int = 768,
context_tokens: int = 32,
device=None, dtype=None, operations=None
):
super().__init__()
self.seq_len = seq_len
self.blocks = blocks
self.channels = channels
self.input_dim = seq_len * blocks * channels
self.input_dim_vf = seq_len_vf * blocks * channels
self.intermediate_dim = intermediate_dim
self.context_tokens = context_tokens
self.out_dim = out_dim
# define multiple linear layers
self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype)
self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype)
self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype)
self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype)
self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype)
def forward(self, audio_embeds, audio_embeds_vf):
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
B, _, _, S, C = audio_embeds.shape
# process audio of first frame
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
batch_size, window_size, blocks, channels = audio_embeds.shape
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
# process audio of latter frame
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
# first projection
audio_embeds = torch.relu(self.proj1(audio_embeds))
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
batch_size_c, N_t, C_a = audio_embeds_c.shape
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
# second projection
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim)
# normalization and reshape
context_tokens = self.norm(context_tokens)
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
return context_tokens
class WanMultiTalkAttentionBlock(torch.nn.Module):
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
super().__init__()
self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations)
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
class MultiTalkGetAttnMapPatch:
def __init__(self, ref_target_masks=None):
self.ref_target_masks = ref_target_masks
def __call__(self, kwargs):
transformer_options = kwargs.get("transformer_options", {})
x = kwargs["x"]
if self.ref_target_masks is not None:
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
transformer_options["x_ref_attn_map"] = x_ref_attn_map
return x
class MultiTalkCrossAttnPatch:
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
self.model_patch = model_patch
self.audio_scale = audio_scale
self.ref_target_masks = ref_target_masks
def __call__(self, kwargs):
transformer_options = kwargs.get("transformer_options", {})
block_idx = transformer_options.get("block_index", None)
x = kwargs["x"]
if block_idx is None:
return torch.zeros_like(x)
audio_embeds = transformer_options.get("audio_embeds")
x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
norm_x, audio_embeds.to(x.dtype),
shape=transformer_options["grid_sizes"],
x_ref_attn_map=x_ref_attn_map
)
x = x + x_audio * self.audio_scale
return x
def models(self):
return [self.model_patch]
class MultiTalkApplyModelWrapper:
def __init__(self, init_latents):
self.init_latents = init_latents
def __call__(self, executor, x, *args, **kwargs):
x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x)
samples = executor(x, *args, **kwargs)
return samples
class InfiniteTalkOuterSampleWrapper:
def __init__(self, motion_frames_latent, model_patch, is_extend=False):
self.motion_frames_latent = motion_frames_latent
self.model_patch = model_patch
self.is_extend = is_extend
def __call__(self, executor, *args, **kwargs):
model_patcher = executor.class_obj.model_patcher
model_options = executor.class_obj.model_options
process_latent_in = model_patcher.model.process_latent_in
# for InfiniteTalk, the first latent(s) of the model input always need to be replaced on every step
if self.motion_frames_latent is not None:
wrappers = model_options["transformer_options"]["wrappers"]
w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]
# run the sampling process
result = executor(*args, **kwargs)
# insert motion frames before decoding
if self.is_extend:
overlap = self.motion_frames_latent.shape[2]
result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
return result
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
if self.motion_frames_latent is not None:
self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)
return self

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
import comfy.ops
ops = comfy.ops.disable_weight_init
@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
self._padding = 2 * self.padding[0]
self.padding = (0, self.padding[1], self.padding[2])
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
if cache_list is not None:
cache_x = cache_list[cache_idx]
cache_list[cache_idx] = None
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
if cache_x is None and x.shape[2] == 1:
#Fast path - the op will pad for us by truncating the weight
#and save math on a pile of zeros.
return super().forward(x, autopad="causal_zero")
if self._padding > 0:
padding_needed = self._padding
if cache_x is not None:
cache_x = cache_x.to(x.device)
padding_needed = max(0, padding_needed - cache_x.shape[2])
padding_shape = list(x.shape)
padding_shape[2] = padding_needed
padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
x = torch_cat_if_needed([padding, cache_x, x], dim=2)
del cache_x
x = F.pad(x, padding)
return super().forward(x)
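# Editor's sketch of why the causal_zero fast path above is safe (toy sizes, plain torch instead of
# comfy.ops): for a one-frame input, convolving with only the last temporal slice of the weight equals
# convolving the causally zero-padded input with the full weight.
import torch
import torch.nn.functional as F
w = torch.randn(8, 4, 3, 3, 3)                           # (out, in, kT, kH, kW)
x = torch.randn(1, 4, 1, 16, 16)                         # a single temporal frame
full = F.conv3d(F.pad(x, (1, 1, 1, 1, 2, 0)), w)         # causal: two zero frames in front, none behind
fast = F.conv3d(x, w[:, :, -1:], padding=(0, 1, 1))      # keep only the last temporal tap
assert torch.allclose(full, fast, atol=1e-5)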

View File

@ -260,6 +260,7 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
key_map[k[:-len(".weight")]] = to #DiffSynth lora format
for k in sdk:
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
if k.endswith(".weight") and ".linear1." in k:

View File

@ -49,6 +49,7 @@ import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.model_management
import comfy.patcher_extension
@ -1147,6 +1148,27 @@ class CosmosPredict2(BaseModel):
sigma = (sigma / (sigma + 1))
return latent_image / (1.0 - sigma)
class Anima(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
t5xxl_ids = kwargs.get("t5xxl_ids", None)
t5xxl_weights = kwargs.get("t5xxl_weights", None)
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
if cross_attn.shape[1] < 512:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
@ -1556,6 +1578,9 @@ class QwenImage(BaseModel):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

View File

@ -550,6 +550,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
dit_config = {}
dit_config["image_model"] = "cosmos_predict2"
if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "anima"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128

View File

@ -203,7 +203,9 @@ class disable_weight_init:
def reset_parameters(self):
return None
def _conv_forward(self, input, weight, bias, *args, **kwargs):
def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
if autopad == "causal_zero":
weight = weight[:, :, -input.shape[2]:, :, :]
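# with "causal_zero" the trimmed leading taps would only ever see zero padding, so dropping them from the weight gives the same output as padding the input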
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
if bias is not None:
@ -212,15 +214,15 @@ class disable_weight_init:
else:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input):
def forward_comfy_cast_weights(self, input, autopad=None):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
x = self._conv_forward(input, weight, bias, autopad=autopad)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)

View File

@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None):
return noises
def fix_empty_latent_channels(model, latent_image):
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
if latent_image.is_nested:
return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if torch.count_nonzero(latent_image) == 0:
if latent_format.latent_channels != latent_image.shape[1]:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if downscale_ratio_spacial is not None:
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled")
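# e.g. an empty latent built for an 8x spatial downscale fed to a 16x model is nearest-resized here to half its width and height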
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
latent_image = latent_image.unsqueeze(2)
return latent_image

View File

@ -57,6 +57,7 @@ import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.newbie
import comfy.text_encoders.anima
import comfy.model_patcher
import comfy.lora
@ -635,14 +636,13 @@ class VAE:
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
if self.latent_channels == 48: # Wan 2.2
if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_input = self.process_output = lambda image: image
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
@ -1048,6 +1048,7 @@ class TEModel(Enum):
GEMMA_3_12B = 18
JINA_CLIP_2 = 19
QWEN3_8B = 20
QWEN3_06B = 21
def detect_te_model(sd):
@ -1093,6 +1094,8 @@ def detect_te_model(sd):
return TEModel.QWEN3_2B
elif weight.shape[0] == 4096:
return TEModel.QWEN3_8B
elif weight.shape[0] == 1024:
return TEModel.QWEN3_06B
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
@ -1233,6 +1236,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:

View File

@ -23,6 +23,7 @@ import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.anima
from . import supported_models_base
from . import latent_formats
@ -770,10 +771,24 @@ class Flux2(Flux):
return out
def clip_target(self, state_dict={}):
return None # TODO
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
if len(detect) > 0:
detect["model_type"] = "qwen3_4b"
return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer, comfy.text_encoders.flux.klein_te(**detect))
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_8b.transformer.".format(pref))
if len(detect) > 0:
detect["model_type"] = "qwen3_8b"
return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer8B, comfy.text_encoders.flux.klein_te(**detect))
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}mistral3_24b.transformer.".format(pref))
if len(detect) > 0:
if "{}mistral3_24b.transformer.model.layers.39.post_attention_layernorm.weight".format(pref) not in state_dict:
detect["pruned"] = True
return supported_models_base.ClipTarget(comfy.text_encoders.flux.Flux2Tokenizer, comfy.text_encoders.flux.flux2_te(**detect))
return None
class GenmoMochi(supported_models_base.BASE):
unet_config = {
@ -992,6 +1007,36 @@ class CosmosT2IPredict2(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
class Anima(supported_models_base.BASE):
unet_config = {
"image_model": "anima",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 3.0,
}
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Anima(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
class CosmosI2VPredict2(CosmosT2IPredict2):
unet_config = {
"image_model": "cosmos_predict2",
@ -1551,6 +1596,6 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]

View File

@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
class TAEHV(nn.Module):
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True),
latent_format=None, show_progress_bar=False):
super().__init__()
self.image_channels = 3
self.patch_size = 1
@ -124,6 +125,9 @@ class TAEHV(nn.Module):
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
self.patch_size = 2
elif self.latent_channels == 128: # LTX2
self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True)
if self.latent_channels == 32: # HunyuanVideo1.5
act_func = nn.LeakyReLU(0.2, inplace=True)
else: # HunyuanVideo, Wan 2.1
@ -131,41 +135,52 @@ class TAEHV(nn.Module):
self.encoder = nn.Sequential(
conv(self.image_channels*self.patch_size**2, 64), act_func,
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False),
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
)
@property
def show_progress_bar(self):
return self._show_progress_bar
@show_progress_bar.setter
def show_progress_bar(self, value):
self._show_progress_bar = value
self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow))
self.frames_to_trim = self.t_upscale - 1
self._show_progress_bar = show_progress_bar
@property
def show_progress_bar(self):
return self._show_progress_bar
@show_progress_bar.setter
def show_progress_bar(self, value):
self._show_progress_bar = value
def encode(self, x, **kwargs):
if self.patch_size > 1:
x = F.pixel_unshuffle(x, self.patch_size)
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if x.shape[1] % 4 != 0:
# pad at end to multiple of 4
n_pad = 4 - x.shape[1] % 4
if self.patch_size > 1:
B, T, C, H, W = x.shape
x = x.reshape(B * T, C, H, W)
x = F.pixel_unshuffle(x, self.patch_size)
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
if x.shape[1] % self.t_downscale != 0:
# pad at end to multiple of t_downscale
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1)
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
return self.process_out(x)
def decode(self, x, **kwargs):
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
if self.patch_size > 1:

View File

@ -0,0 +1,61 @@
from transformers import Qwen2Tokenizer, T5TokenizerFast
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
import torch
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
class AnimaTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
return self.t5xxl.untokenize(token_weight_pair)
def state_dict(self):
return {}
class Qwen3_06BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class AnimaTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out = super().encode_token_weights(token_weight_pairs)
out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int)
out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0])))
return out
def te(dtype_llama=None, llama_quantization_metadata=None):
class AnimaTEModel_(AnimaTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return AnimaTEModel_

View File

@ -10,9 +10,11 @@ import comfy.utils
def llama_detect(state_dict, prefix=""):
out = {}
t5_key = "{}model.norm.weight".format(prefix)
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
norm_keys = ["{}model.norm.weight".format(prefix), "{}model.layers.0.input_layernorm.weight".format(prefix)]
for norm_key in norm_keys:
if norm_key in state_dict:
out["dtype_llama"] = state_dict[norm_key].dtype
break
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
if quant is not None:

View File

@ -77,6 +77,28 @@ class Qwen25_3BConfig:
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen3_06BConfig:
vocab_size: int = 151936
hidden_size: int = 1024
intermediate_size: int = 3072
num_hidden_layers: int = 28
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 32768
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
@ -641,6 +663,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_06B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@ -118,9 +118,18 @@ class LTXAVTEModel(torch.nn.Module):
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
missing, unexpected = self.load_state_dict(sdo, strict=False)
missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
return (missing, unexpected)
missing_all = []
unexpected_all = []
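# load each projection/connector separately so the missing keys reported here only ever belong to these submodules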
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
if component_sd:
missing, unexpected = component.load_state_dict(component_sd, strict=False)
missing_all.extend([f"{prefix}{k}" for k in missing])
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
return (missing_all, unexpected_all)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0

View File

@ -611,6 +611,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
"ff.linear_in.weight": "img_mlp.0.weight", # LyCoris LoKr
"ff.linear_in.bias": "img_mlp.0.bias",
"ff.linear_out.weight": "img_mlp.2.weight",
"ff.linear_out.bias": "img_mlp.2.bias",
"ff_context.linear_in.weight": "txt_mlp.0.weight",
"ff_context.linear_in.bias": "txt_mlp.0.bias",
"ff_context.linear_out.weight": "txt_mlp.2.weight",
"ff_context.linear_out.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",

View File

@ -754,7 +754,7 @@ class AnyType(ComfyTypeIO):
Type = Any
@comfytype(io_type="MODEL_PATCH")
class MODEL_PATCH(ComfyTypeIO):
class ModelPatch(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO_ENCODER")
@ -1249,6 +1249,7 @@ class NodeInfoV1:
experimental: bool=None
api_node: bool=None
price_badge: dict | None = None
search_aliases: list[str]=None
@dataclass
class NodeInfoV3:
@ -1346,6 +1347,8 @@ class Schema:
hidden: list[Hidden] = field(default_factory=list)
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
search_aliases: list[str] = field(default_factory=list)
"""Alternative names for search. Useful for synonyms, abbreviations, or old names after renaming."""
is_input_list: bool = False
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
@ -1483,6 +1486,7 @@ class Schema:
api_node=self.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
search_aliases=self.search_aliases if self.search_aliases else None,
)
return info
@ -2034,6 +2038,7 @@ __all__ = [
"ControlNet",
"Vae",
"Model",
"ModelPatch",
"ClipVision",
"ClipVisionOutput",
"AudioEncoder",

View File

@ -24,7 +24,7 @@ class BriaImageEditNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="BriaImageEditNode",
display_name="Bria Image Edit",
display_name="Bria FIBO Image Edit",
category="api node/image/Bria",
description="Edit images using Bria latest model",
inputs=[

View File

@ -364,9 +364,9 @@ class OpenAIGPTImage1(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 1",
display_name="OpenAI GPT Image 1.5",
category="api node/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
inputs=[
IO.String.Input(
"prompt",
@ -429,6 +429,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
IO.Combo.Input(
"model",
options=["gpt-image-1", "gpt-image-1.5"],
default="gpt-image-1.5",
optional=True,
),
],

View File

@ -703,7 +703,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
"subjects",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("reference_images"),
names=["subject1", "subject2", "subject3"],
names=["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"],
min=1,
),
tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). "
@ -738,7 +738,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
control_after_generate=True,
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]),
IO.Combo.Input("resolution", options=["720p"]),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Combo.Input(
"movement_amplitude",
options=["auto", "small", "medium", "large"],

View File

@ -28,6 +28,7 @@ class AlignYourStepsScheduler(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AlignYourStepsScheduler",
search_aliases=["AYS scheduler"],
category="sampling/custom_sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),

View File

@ -71,6 +71,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPAttentionMultiply",
search_aliases=["clip attention scale", "text encoder attention"],
category="_for_testing/attention_experiments",
inputs=[
io.Clip.Input("clip"),

View File

@ -69,6 +69,7 @@ class VAEEncodeAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VAEEncodeAudio",
search_aliases=["audio to latent"],
display_name="VAE Encode Audio",
category="latent/audio",
inputs=[
@ -97,6 +98,7 @@ class VAEDecodeAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeAudio",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio",
category="latent/audio",
inputs=[
@ -122,6 +124,7 @@ class SaveAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveAudio",
search_aliases=["export flac"],
display_name="Save Audio (FLAC)",
category="audio",
inputs=[
@ -146,6 +149,7 @@ class SaveAudioMP3(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioMP3",
search_aliases=["export mp3"],
display_name="Save Audio (MP3)",
category="audio",
inputs=[
@ -173,6 +177,7 @@ class SaveAudioOpus(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioOpus",
search_aliases=["export opus"],
display_name="Save Audio (Opus)",
category="audio",
inputs=[
@ -200,6 +205,7 @@ class PreviewAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="PreviewAudio",
search_aliases=["play audio"],
display_name="Preview Audio",
category="audio",
inputs=[
@ -259,6 +265,7 @@ class LoadAudio(IO.ComfyNode):
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return IO.Schema(
node_id="LoadAudio",
search_aliases=["import audio", "open audio", "audio file"],
display_name="Load Audio",
category="audio",
inputs=[
@ -296,6 +303,7 @@ class RecordAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="RecordAudio",
search_aliases=["microphone input", "audio capture", "voice input"],
display_name="Record Audio",
category="audio",
inputs=[
@ -320,6 +328,7 @@ class TrimAudioDuration(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TrimAudioDuration",
search_aliases=["cut audio", "audio clip", "shorten audio"],
display_name="Trim Audio Duration",
description="Trim audio tensor into chosen time range.",
category="audio",
@ -372,6 +381,7 @@ class SplitAudioChannels(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SplitAudioChannels",
search_aliases=["stereo to mono"],
display_name="Split Audio Channels",
description="Separates the audio into left and right channels.",
category="audio",
@ -472,6 +482,7 @@ class AudioConcat(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="AudioConcat",
search_aliases=["join audio", "combine audio", "append audio"],
display_name="Audio Concat",
description="Concatenates the audio1 to audio2 in the specified direction.",
category="audio",
@ -519,6 +530,7 @@ class AudioMerge(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="AudioMerge",
search_aliases=["mix audio", "overlay audio", "layer audio"],
display_name="Audio Merge",
description="Combine two audio tracks by overlaying their waveforms.",
category="audio",
@ -579,6 +591,7 @@ class AudioAdjustVolume(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="AudioAdjustVolume",
search_aliases=["audio gain", "loudness", "audio level"],
display_name="Audio Adjust Volume",
category="audio",
inputs=[
@ -614,6 +627,7 @@ class EmptyAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="EmptyAudio",
search_aliases=["blank audio"],
display_name="Empty Audio",
category="audio",
inputs=[

View File

@ -10,6 +10,7 @@ class Canny(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Canny",
search_aliases=["edge detection", "outline", "contour detection", "line art"],
category="image/preprocessors",
inputs=[
io.Image.Input("image"),

View File

@ -109,6 +109,7 @@ class PorterDuffImageComposite(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PorterDuffImageComposite",
search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"],
display_name="Porter-Duff Image Composite",
category="mask/compositing",
inputs=[
@ -165,6 +166,7 @@ class SplitImageWithAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SplitImageWithAlpha",
search_aliases=["extract alpha", "separate transparency", "remove alpha"],
display_name="Split Image with Alpha",
category="mask/compositing",
inputs=[
@ -188,6 +190,7 @@ class JoinImageWithAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="JoinImageWithAlpha",
search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"],
display_name="Join Image with Alpha",
category="mask/compositing",
inputs=[

View File

@ -38,6 +38,7 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ControlNetInpaintingAliMamaApply",
search_aliases=["masked controlnet"],
category="conditioning/controlnet",
inputs=[
io.Conditioning.Input("positive"),

View File

@ -297,6 +297,7 @@ class ExtendIntermediateSigmas(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ExtendIntermediateSigmas",
search_aliases=["interpolate sigmas"],
category="sampling/custom_sampling/sigmas",
inputs=[
io.Sigmas.Input("sigmas"),
@ -740,7 +741,7 @@ class SamplerCustom(io.ComfyNode):
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
latent["samples"] = latent_image
if not add_noise:
@ -759,6 +760,7 @@ class SamplerCustom(io.ComfyNode):
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = samples
if "x0" in x0_output:
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
@ -856,6 +858,7 @@ class DualCFGGuider(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="DualCFGGuider",
search_aliases=["dual prompt guidance"],
category="sampling/custom_sampling/guiders",
inputs=[
io.Model.Input("model"),
@ -883,6 +886,7 @@ class DisableNoise(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="DisableNoise",
search_aliases=["zero noise"],
category="sampling/custom_sampling/noise",
inputs=[],
outputs=[io.Noise.Output()]
@ -936,7 +940,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None))
latent["samples"] = latent_image
noise_mask = None
@ -951,6 +955,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
samples = samples.to(comfy.model_management.intermediate_device())
out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = samples
if "x0" in x0_output:
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
@ -1019,6 +1024,7 @@ class ManualSigmas(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ManualSigmas",
search_aliases=["custom noise schedule", "define sigmas"],
category="_for_testing/custom_sampling",
is_experimental=True,
inputs=[

View File

@ -1223,11 +1223,11 @@ class ResolutionBucket(io.ComfyNode):
class MakeTrainingDataset(io.ComfyNode):
"""Encode images with VAE and texts with CLIP to create a training dataset."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MakeTrainingDataset",
search_aliases=["encode dataset"],
display_name="Make Training Dataset",
category="dataset",
is_experimental=True,
@ -1309,11 +1309,11 @@ class MakeTrainingDataset(io.ComfyNode):
class SaveTrainingDataset(io.ComfyNode):
"""Save encoded training dataset (latents + conditioning) to disk."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveTrainingDataset",
search_aliases=["export training data"],
display_name="Save Training Dataset",
category="dataset",
is_experimental=True,
@ -1410,11 +1410,11 @@ class SaveTrainingDataset(io.ComfyNode):
class LoadTrainingDataset(io.ComfyNode):
"""Load encoded training dataset from disk."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadTrainingDataset",
search_aliases=["import dataset", "training data"],
display_name="Load Training Dataset",
category="dataset",
is_experimental=True,

View File

@ -11,6 +11,7 @@ class DifferentialDiffusion(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="DifferentialDiffusion",
search_aliases=["inpaint gradient", "variable denoise strength"],
display_name="Differential Diffusion",
category="_for_testing",
inputs=[

View File

@ -29,8 +29,10 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
do_easycache = easycache.should_do_easycache(sigmas)
if do_easycache:
easycache.check_metadata(x)
# if there isn't a cache diff for the current conds, we cannot skip this step
can_apply_cache_diff = easycache.can_apply_cache_diff(uuids)
# if first cond marked this step for skipping, skip it and use appropriate cached values
if easycache.skip_current_step:
if easycache.skip_current_step and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
return easycache.apply_cache_diff(x, uuids)
@ -44,7 +46,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values
@ -240,6 +242,9 @@ class EasyCacheHolder:
return to_return.clone()
return to_return
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
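# a step can only be skipped when every cond in the current batch already has a cached diff to reuse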
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
if self.first_cond_uuid in uuids:
self.total_steps_skipped += 1

View File

@ -58,6 +58,7 @@ class FreSca(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FreSca",
search_aliases=["frequency guidance"],
display_name="FreSca",
category="_for_testing",
description="Applies frequency-dependent scaling to the guidance",

View File

@ -38,6 +38,7 @@ class CLIPTextEncodeHiDream(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeHiDream",
search_aliases=["hidream prompt"],
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),

View File

@ -259,6 +259,7 @@ class SetClipHooks:
return (clip,)
class ConditioningTimestepsRange:
SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"]
NodeId = 'ConditioningTimestepsRange'
NodeName = 'Timesteps Range'
@classmethod
@ -468,6 +469,7 @@ class SetHookKeyframes:
return (hooks,)
class CreateHookKeyframe:
SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"]
NodeId = 'CreateHookKeyframe'
NodeName = 'Create Hook Keyframe'
@classmethod
@ -497,6 +499,7 @@ class CreateHookKeyframe:
return (prev_hook_kf,)
class CreateHookKeyframesInterpolated:
SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"]
NodeId = 'CreateHookKeyframesInterpolated'
NodeName = 'Create Hook Keyframes Interp.'
@classmethod
@ -544,6 +547,7 @@ class CreateHookKeyframesInterpolated:
return (prev_hook_kf,)
class CreateHookKeyframesFromFloats:
SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"]
NodeId = 'CreateHookKeyframesFromFloats'
NodeName = 'Create Hook Keyframes From Floats'
@classmethod
@ -618,6 +622,7 @@ class SetModelHooksOnCond:
# Combine Hooks
#------------------------------------------
class CombineHooks:
SEARCH_ALIASES = ["merge hooks"]
NodeId = 'CombineHooks2'
NodeName = 'Combine Hooks [2]'
@classmethod

View File

@ -618,6 +618,7 @@ class SaveGLB(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="SaveGLB",
search_aliases=["export 3d model", "save mesh"],
category="3d",
is_output_node=True,
inputs=[

View File

@ -22,6 +22,7 @@ class ImageCrop(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageCrop",
search_aliases=["trim"],
display_name="Image Crop",
category="image/transform",
inputs=[
@ -51,6 +52,7 @@ class RepeatImageBatch(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="RepeatImageBatch",
search_aliases=["duplicate image", "clone image"],
category="image/batch",
inputs=[
IO.Image.Input("image"),
@ -72,6 +74,7 @@ class ImageFromBatch(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageFromBatch",
search_aliases=["select image", "pick from batch", "extract image"],
category="image/batch",
inputs=[
IO.Image.Input("image"),
@ -97,6 +100,7 @@ class ImageAddNoise(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageAddNoise",
search_aliases=["film grain"],
category="image",
inputs=[
IO.Image.Input("image"),
@ -194,11 +198,11 @@ class SaveAnimatedPNG(IO.ComfyNode):
class ImageStitch(IO.ComfyNode):
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageStitch",
search_aliases=["combine images", "join images", "concatenate images", "side by side"],
display_name="Image Stitch",
description="Stitches image2 to image1 in the specified direction.\n"
"If image2 is not provided, returns image1 unchanged.\n"
@ -369,11 +373,11 @@ class ImageStitch(IO.ComfyNode):
class ResizeAndPadImage(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ResizeAndPadImage",
search_aliases=["fit to size"],
category="image/transform",
inputs=[
IO.Image.Input("image"),
@ -420,11 +424,11 @@ class ResizeAndPadImage(IO.ComfyNode):
class SaveSVGNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveSVGNode",
search_aliases=["export vector", "save vector graphics"],
description="Save SVG files on disk.",
category="image/save",
inputs=[
@ -492,11 +496,11 @@ class SaveSVGNode(IO.ComfyNode):
class GetImageSize(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GetImageSize",
search_aliases=["dimensions", "resolution", "image info"],
display_name="Get Image Size",
description="Returns width and height of the image, and passes it through unchanged.",
category="image",
@ -527,11 +531,11 @@ class GetImageSize(IO.ComfyNode):
class ImageRotate(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageRotate",
search_aliases=["turn", "flip orientation"],
category="image/transform",
inputs=[
IO.Image.Input("image"),
@ -557,11 +561,11 @@ class ImageRotate(IO.ComfyNode):
class ImageFlip(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageFlip",
search_aliases=["mirror", "reflect"],
category="image/transform",
inputs=[
IO.Image.Input("image"),

View File

@ -104,6 +104,7 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeKandinsky5",
search_aliases=["kandinsky prompt"],
category="advanced/conditioning/kandinsky5",
inputs=[
io.Clip.Input("clip"),

View File

@ -21,6 +21,7 @@ class LatentAdd(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentAdd",
search_aliases=["combine latents", "sum latents"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@ -47,6 +48,7 @@ class LatentSubtract(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentSubtract",
search_aliases=["difference latent", "remove features"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@ -73,6 +75,7 @@ class LatentMultiply(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentMultiply",
search_aliases=["scale latent", "amplify latent", "latent gain"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
@ -96,6 +99,7 @@ class LatentInterpolate(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentInterpolate",
search_aliases=["blend latent", "mix latent", "lerp latent", "transition"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@ -134,6 +138,7 @@ class LatentConcat(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentConcat",
search_aliases=["join latents", "stitch latents"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
@ -173,6 +178,7 @@ class LatentCut(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentCut",
search_aliases=["crop latent", "slice latent", "extract region"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
@ -213,6 +219,7 @@ class LatentCutToBatch(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentCutToBatch",
search_aliases=["slice to batch", "split latent", "tile latent"],
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
@ -254,6 +261,7 @@ class LatentBatch(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentBatch",
search_aliases=["combine latents", "merge latents", "join latents"],
category="latent/batch",
is_deprecated=True,
inputs=[
@ -310,6 +318,7 @@ class LatentApplyOperation(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperation",
search_aliases=["transform latent"],
category="latent/advanced/operations",
is_experimental=True,
inputs=[
@ -365,6 +374,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentOperationTonemapReinhard",
search_aliases=["hdr latent"],
category="latent/advanced/operations",
is_experimental=True,
inputs=[

View File

@ -24,7 +24,7 @@ class Load3D(IO.ComfyNode):
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'}
]
return IO.Schema(
node_id="Load3D",
@ -75,6 +75,7 @@ class Preview3D(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="Preview3D",
search_aliases=["view mesh", "3d viewer"],
display_name="Preview 3D & Animation",
category="3d",
is_experimental=True,

View File

@ -224,6 +224,7 @@ class ConvertStringToComboNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ConvertStringToComboNode",
search_aliases=["string to dropdown", "text to combo"],
display_name="Convert String to Combo",
category="logic",
inputs=[io.String.Input("string")],
@ -239,6 +240,7 @@ class InvertBooleanNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="InvertBooleanNode",
search_aliases=["not", "toggle", "negate", "flip boolean"],
display_name="Invert Boolean",
category="logic",
inputs=[io.Boolean.Input("boolean")],

View File

@ -78,6 +78,7 @@ class LoraSave(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LoraSave",
search_aliases=["export lora"],
display_name="Extract and Save Lora",
category="_for_testing",
inputs=[

View File

@ -79,6 +79,7 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeLumina2",
search_aliases=["lumina prompt"],
display_name="CLIP Text Encode for Lumina2",
category="conditioning",
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "

View File

@ -50,6 +50,7 @@ class LatentCompositeMasked(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="LatentCompositeMasked",
search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"],
category="latent",
inputs=[
IO.Latent.Input("destination"),
@ -78,6 +79,7 @@ class ImageCompositeMasked(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageCompositeMasked",
search_aliases=["paste image", "overlay", "layer"],
category="image",
inputs=[
IO.Image.Input("destination"),
@ -105,6 +107,7 @@ class MaskToImage(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="MaskToImage",
search_aliases=["convert mask"],
display_name="Convert Mask to Image",
category="mask",
inputs=[
@ -126,6 +129,7 @@ class ImageToMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageToMask",
search_aliases=["extract channel", "channel to mask"],
display_name="Convert Image to Mask",
category="mask",
inputs=[
@ -149,6 +153,7 @@ class ImageColorToMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ImageColorToMask",
search_aliases=["color keying", "chroma key"],
category="mask",
inputs=[
IO.Image.Input("image"),
@ -194,6 +199,7 @@ class InvertMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="InvertMask",
search_aliases=["reverse mask", "flip mask"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
@ -214,6 +220,7 @@ class CropMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="CropMask",
search_aliases=["cut mask", "extract mask region", "mask slice"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
@ -239,6 +246,7 @@ class MaskComposite(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="MaskComposite",
search_aliases=["combine masks", "blend masks", "layer masks"],
category="mask",
inputs=[
IO.Mask.Input("destination"),
@ -287,6 +295,7 @@ class FeatherMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="FeatherMask",
search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
@ -333,6 +342,7 @@ class GrowMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="GrowMask",
search_aliases=["expand mask", "shrink mask"],
display_name="Grow Mask",
category="mask",
inputs=[
@ -370,6 +380,7 @@ class ThresholdMask(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ThresholdMask",
search_aliases=["binary mask"],
category="mask",
inputs=[
IO.Mask.Input("mask"),
@ -394,6 +405,7 @@ class MaskPreview(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="MaskPreview",
search_aliases=["show mask", "view mask", "inspect mask", "debug mask"],
display_name="Preview Mask",
category="mask",
description="Saves the input images to your ComfyUI output directory.",

View File

@ -299,6 +299,7 @@ class RescaleCFG:
return (m, )
class ModelComputeDtype:
SEARCH_ALIASES = ["model precision", "change dtype"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),

View File

@ -91,6 +91,7 @@ class CLIPMergeSimple:
class CLIPSubtract:
SEARCH_ALIASES = ["clip difference", "text encoder subtract"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
@ -113,6 +114,7 @@ class CLIPSubtract:
class CLIPAdd:
SEARCH_ALIASES = ["combine clip"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip1": ("CLIP",),
@ -225,6 +227,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys)
class CheckpointSave:
SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -337,6 +340,7 @@ class VAESave:
return {}
class ModelSave:
SEARCH_ALIASES = ["export model", "checkpoint save"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()

View File

@ -7,6 +7,7 @@ import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
import comfy.ldm.lumina.controlnet
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
class BlockWiseControlBlock(torch.nn.Module):
@ -257,6 +258,14 @@ class ModelPatchLoader:
if torch.count_nonzero(ref_weight) == 0:
config['broken'] = True
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
elif "audio_proj.proj1.weight" in sd:
model = MultiTalkModelPatch(
audio_window=5, context_tokens=32, vae_scale=4,
in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0],
intermediate_dim=sd["audio_proj.proj1.weight"].shape[0],
out_dim=sd["audio_proj.norm.weight"].shape[0],
device=comfy.model_management.unet_offload_device(),
operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@ -524,6 +533,38 @@ class USOStyleReference:
return (model_patched,)
class MultiTalkModelPatch(torch.nn.Module):
def __init__(
self,
audio_window: int = 5,
intermediate_dim: int = 512,
in_dim: int = 5120,
out_dim: int = 768,
context_tokens: int = 32,
vae_scale: int = 4,
num_layers: int = 40,
device=None, dtype=None, operations=None
):
super().__init__()
self.audio_proj = MultiTalkAudioProjModel(
seq_len=audio_window,
seq_len_vf=audio_window+vae_scale-1,
intermediate_dim=intermediate_dim,
out_dim=out_dim,
context_tokens=context_tokens,
device=device,
dtype=dtype,
operations=operations
)
self.blocks = torch.nn.ModuleList(
[
WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
]
)
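The branch added to ModelPatchLoader above sizes this patch purely from state-dict shapes. A small illustrative sketch of that pattern with toy tensors (the keys mirror the diff; the tensor sizes are made up):

import torch

sd = {
    "audio_proj.proj1.weight": torch.zeros(512, 3840),                # toy shapes only
    "audio_proj.norm.weight": torch.zeros(768),
    "blocks.0.audio_cross_attn.proj.weight": torch.zeros(5120, 5120),
}
intermediate_dim = sd["audio_proj.proj1.weight"].shape[0]             # 512
out_dim = sd["audio_proj.norm.weight"].shape[0]                       # 768
in_dim = sd["blocks.0.audio_cross_attn.proj.weight"].shape[0]         # 5120
print(in_dim, intermediate_dim, out_dim)                              # 5120 512 768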
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,

View File

@ -12,6 +12,7 @@ class Morphology(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Morphology",
search_aliases=["erode", "dilate"],
display_name="ImageMorphology",
category="image/postprocessing",
inputs=[
@ -57,6 +58,7 @@ class ImageRGBToYUV(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageRGBToYUV",
search_aliases=["color space conversion"],
category="image/batch",
inputs=[
io.Image.Input("image"),
@ -78,6 +80,7 @@ class ImageYUVToRGB(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageYUVToRGB",
search_aliases=["color space conversion"],
category="image/batch",
inputs=[
io.Image.Input("Y"),

View File

@ -7,6 +7,7 @@ class CLIPTextEncodePixArtAlpha(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodePixArtAlpha",
search_aliases=["pixart prompt"],
category="advanced/conditioning",
description="Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma.",
inputs=[

View File

@ -402,7 +402,6 @@ def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: st
return input[:, y0:y1, x0:x1]
class ResizeImageMaskNode(io.ComfyNode):
scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"]
@ -421,46 +420,62 @@ class ResizeImageMaskNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.MatchType.Template("input_type", [io.Image, io.Mask])
crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center")
crop_combo = io.Combo.Input(
"crop",
options=cls.crop_methods,
default="center",
tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.",
)
return io.Schema(
node_id="ResizeImageMaskNode",
display_name="Resize Image/Mask",
description="Resize an image or mask using various scaling methods.",
category="transform",
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
inputs=[
io.MatchType.Input("input", template=template),
io.DynamicCombo.Input("resize_type", options=[
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
io.DynamicCombo.Input(
"resize_type",
tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.",
options=[
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."),
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."),
crop_combo,
]),
io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
crop_combo,
io.DynamicCombo.Option(ResizeType.SCALE_BY, [
io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."),
]),
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The longer edge will be resized to this value. Aspect ratio is preserved."),
]),
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The shorter edge will be resized to this value. Aspect ratio is preserved."),
]),
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Height auto-adjusts to preserve aspect ratio."),
]),
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Width auto-adjusts to preserve aspect ratio."),
]),
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."),
]),
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
io.MultiType.Input("match", [io.Image, io.Mask]),
crop_combo,
io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."),
crop_combo,
]),
io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1),
io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."),
]),
]),
io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
],
),
io.Combo.Input(
"scale_method",
options=cls.scale_methods,
default="area",
tooltip="Interpolation algorithm. 'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.",
),
],
outputs=[io.MatchType.Output(template=template, display_name="resized")]
)
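The megapixels tooltip above describes scaling to a total pixel budget while preserving aspect ratio. A rough illustration of that arithmetic (the helper name is hypothetical, not the node's actual implementation):

import math

def dims_for_megapixels(width: int, height: int, megapixels: float) -> tuple[int, int]:
    target = megapixels * 1024 * 1024            # 1.0 MP is roughly a 1024x1024 budget
    scale = math.sqrt(target / (width * height)) # single factor keeps the aspect ratio
    return round(width * scale), round(height * scale)

print(dims_for_megapixels(1920, 1080, 1.0))      # ~(1365, 768): same aspect ratio, ~1.0 MP total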
@ -550,6 +565,7 @@ class BatchImagesNode(io.ComfyNode):
node_id="BatchImagesNode",
display_name="Batch Images",
category="image",
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
inputs=[
io.Autogrow.Input("images", template=autogrow_template)
],
@ -568,6 +584,7 @@ class BatchMasksNode(io.ComfyNode):
autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50)
return io.Schema(
node_id="BatchMasksNode",
search_aliases=["combine masks", "stack masks", "merge masks"],
display_name="Batch Masks",
category="mask",
inputs=[
@ -588,6 +605,7 @@ class BatchLatentsNode(io.ComfyNode):
autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50)
return io.Schema(
node_id="BatchLatentsNode",
search_aliases=["combine latents", "stack latents", "merge latents"],
display_name="Batch Latents",
category="latent",
inputs=[
@ -611,6 +629,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
prefix="input", min=1, max=50)
return io.Schema(
node_id="BatchImagesMasksLatentsNode",
search_aliases=["combine batch", "merge batch", "stack inputs"],
display_name="Batch Images/Masks/Latents",
category="util",
inputs=[

View File

@ -16,6 +16,7 @@ class PreviewAny():
OUTPUT_NODE = True
CATEGORY = "utils"
SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"]
def main(self, source=None):
value = 'None'

View File

@ -55,7 +55,7 @@ class EmptySD3LatentImage(io.ComfyNode):
@classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples":latent})
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
generate = execute # TODO: remove
@ -65,6 +65,7 @@ class CLIPTextEncodeSD3(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeSD3",
search_aliases=["sd3 prompt"],
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),

View File

@ -11,6 +11,7 @@ class StringConcatenate(io.ComfyNode):
node_id="StringConcatenate",
display_name="Concatenate",
category="utils/string",
search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"],
inputs=[
io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True),
@ -31,6 +32,7 @@ class StringSubstring(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringSubstring",
search_aliases=["extract text", "text portion"],
display_name="Substring",
category="utils/string",
inputs=[
@ -53,6 +55,7 @@ class StringLength(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringLength",
search_aliases=["character count", "text size"],
display_name="Length",
category="utils/string",
inputs=[
@ -73,6 +76,7 @@ class CaseConverter(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CaseConverter",
search_aliases=["text case", "uppercase", "lowercase", "capitalize"],
display_name="Case Converter",
category="utils/string",
inputs=[
@ -105,6 +109,7 @@ class StringTrim(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringTrim",
search_aliases=["clean whitespace", "remove whitespace"],
display_name="Trim",
category="utils/string",
inputs=[
@ -135,6 +140,7 @@ class StringReplace(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringReplace",
search_aliases=["find and replace", "substitute", "swap text"],
display_name="Replace",
category="utils/string",
inputs=[
@ -157,6 +163,7 @@ class StringContains(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringContains",
search_aliases=["text includes", "string includes"],
display_name="Contains",
category="utils/string",
inputs=[
@ -184,6 +191,7 @@ class StringCompare(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StringCompare",
search_aliases=["text match", "string equals", "starts with", "ends with"],
display_name="Compare",
category="utils/string",
inputs=[
@ -219,6 +227,7 @@ class RegexMatch(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="RegexMatch",
search_aliases=["pattern match", "text contains", "string match"],
display_name="Regex Match",
category="utils/string",
inputs=[
@ -259,6 +268,7 @@ class RegexExtract(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="RegexExtract",
search_aliases=["pattern extract", "text parser", "parse text"],
display_name="Regex Extract",
category="utils/string",
inputs=[
@ -333,6 +343,7 @@ class RegexReplace(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="RegexReplace",
search_aliases=["pattern replace", "find and replace", "substitution"],
display_name="Regex Replace",
category="utils/string",
description="Find and replace text using regex patterns.",

View File

@ -1101,6 +1101,7 @@ class SaveLoRA(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveLoRA",
search_aliases=["export lora"],
display_name="Save LoRA Weights",
category="loaders",
is_experimental=True,
@ -1144,6 +1145,7 @@ class LossGraphNode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LossGraphNode",
search_aliases=["training chart", "training visualization", "plot loss"],
display_name="Plot Loss Graph",
category="training",
is_experimental=True,

View File

@ -53,6 +53,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
node_id="ImageUpscaleWithModel",
display_name="Upscale Image (using Model)",
category="image/upscaling",
search_aliases=["upscale", "upscaler", "upsc", "enlarge image", "super resolution", "hires", "superres", "increase resolution"],
inputs=[
io.UpscaleModel.Input("upscale_model"),
io.Image.Input("image"),

View File

@ -16,6 +16,7 @@ class SaveWEBM(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveWEBM",
search_aliases=["export webm"],
category="image/video",
is_experimental=True,
inputs=[
@ -69,6 +70,7 @@ class SaveVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SaveVideo",
search_aliases=["export video"],
display_name="Save Video",
category="image/video",
description="Saves the input images to your ComfyUI output directory.",
@ -116,6 +118,7 @@ class CreateVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="CreateVideo",
search_aliases=["images to video"],
display_name="Create Video",
category="image/video",
description="Create a video from images.",
@ -140,6 +143,7 @@ class GetVideoComponents(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="GetVideoComponents",
search_aliases=["extract frames", "split video", "video to images", "demux"],
display_name="Get Video Components",
category="image/video",
description="Extracts all components from a video: frames, audio, and framerate.",
@ -167,6 +171,7 @@ class LoadVideo(io.ComfyNode):
files = folder_paths.filter_files_content_types(files, ["video"])
return io.Schema(
node_id="LoadVideo",
search_aliases=["import video", "open video", "video file"],
display_name="Load Video",
category="image/video",
inputs=[

View File

@ -8,9 +8,10 @@ import comfy.latent_formats
import comfy.clip_vision
import json
import numpy as np
from typing import Tuple
from typing import Tuple, TypedDict
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import logging
class WanImageToVideo(io.ComfyNode):
@classmethod
@ -286,6 +287,7 @@ class WanVaceToVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="WanVaceToVideo",
search_aliases=["video conditioning", "video control"],
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
@ -704,6 +706,7 @@ class WanTrackToVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="WanTrackToVideo",
search_aliases=["motion tracking", "trajectory video", "point tracking", "keypoint animation"],
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
@ -1288,6 +1291,171 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
return io.NodeOutput(out_latent)
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
class WanInfiniteTalkToVideo(io.ComfyNode):
class DCValues(TypedDict):
mode: str
audio_encoder_output_2: io.AudioEncoderOutput.Type
mask: io.Mask.Type
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanInfiniteTalkToVideo",
category="conditioning/video_models",
inputs=[
io.DynamicCombo.Input("mode", options=[
io.DynamicCombo.Option("single_speaker", []),
io.DynamicCombo.Option("two_speakers", [
io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
]),
]),
io.Model.Input("model"),
io.ModelPatch.Input("model_patch"),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
io.AudioEncoderOutput.Input("audio_encoder_output_1"),
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
io.Image.Input("previous_frames", optional=True),
],
outputs=[
io.Model.Output(display_name="model"),
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Int.Output(display_name="trim_image"),
],
)
@classmethod
def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
raise ValueError("Not enough previous frames provided.")
if mode["mode"] == "two_speakers":
audio_encoder_output_2 = mode["audio_encoder_output_2"]
mask_1 = mode["mask_1"]
mask_2 = mode["mask_2"]
if audio_encoder_output_2 is not None:
if mask_1 is None or mask_2 is None:
raise ValueError("Masks must be provided if two audio encoder outputs are used.")
ref_masks = None
if mask_1 is not None and mask_2 is not None:
if audio_encoder_output_2 is None:
raise ValueError("Second audio encoder output must be provided if two masks are used.")
ref_masks = torch.cat([mask_1, mask_2])
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
model_patched = model.clone()
encoded_audio_list = []
seq_lengths = []
for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]:
if audio_encoder_output is None:
continue
all_layers = audio_encoder_output["encoded_audio_all_layers"]
encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512]
encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512]
encoded_audio_list.append(encoded_audio)
seq_lengths.append(encoded_audio.shape[0])
# Pad / combine depending on multi_audio_type
multi_audio_type = "add"
if len(encoded_audio_list) > 1:
if multi_audio_type == "para":
max_len = max(seq_lengths)
padded = []
for emb in encoded_audio_list:
if emb.shape[0] < max_len:
pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype)
emb = torch.cat([emb, pad], dim=0)
padded.append(emb)
encoded_audio_list = padded
elif multi_audio_type == "add":
total_len = sum(seq_lengths)
full_list = []
offset = 0
for emb, seq_len in zip(encoded_audio_list, seq_lengths):
full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype)
full[offset:offset+seq_len] = emb
full_list.append(full)
offset += seq_len
encoded_audio_list = full_list
token_ref_target_masks = None
if ref_masks is not None:
token_ref_target_masks = torch.nn.functional.interpolate(
ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)
# when extending from previous frames
if previous_frames is not None:
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
frame_offset = previous_frames.shape[0] - motion_frame_count
audio_start = frame_offset
audio_end = audio_start + length
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
trim_image = motion_frame_count
else:
audio_start = trim_image = 0
audio_end = length
motion_frames_latent = concat_latent_image[:, :, :1]
audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
# add outer sample wrapper
model_patched.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
"infinite_talk_outer_sample",
InfiniteTalkOuterSampleWrapper(
motion_frames_latent,
model_patch,
is_extend=previous_frames is not None,
))
# add cross-attention patch
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch")
if token_ref_target_masks is not None:
model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch")
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
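The multi_audio_type handling inside execute lays out the speakers' embeddings differently: "add" places them back to back on one shared timeline, while "para" zero-pads each to the longest sequence so they overlap. A toy sketch of both layouts (shapes are made up; the node above hard-codes the "add" path):

import torch

speakers = [torch.ones(3, 2), torch.full((2, 2), 2.0)]      # two speakers, [frames, dim]
lengths = [s.shape[0] for s in speakers]

# "add": sequential placement on a zero timeline of total length
total = sum(lengths)
offset, sequential = 0, []
for emb, n in zip(speakers, lengths):
    track = torch.zeros(total, *emb.shape[1:], dtype=emb.dtype)
    track[offset:offset + n] = emb
    sequential.append(track)
    offset += n

# "para": zero-pad each speaker to the longest sequence so they run in parallel
max_len = max(lengths)
parallel = [torch.cat([s, torch.zeros(max_len - s.shape[0], *s.shape[1:], dtype=s.dtype)]) for s in speakers]

print([t.shape for t in sequential], [t.shape for t in parallel])
# [torch.Size([5, 2]), torch.Size([5, 2])] [torch.Size([3, 2]), torch.Size([3, 2])]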
class WanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -1307,6 +1475,7 @@ class WanExtension(ComfyExtension):
WanHuMoImageToVideo,
WanAnimateToVideo,
Wan22ImageToVideoLatent,
WanInfiniteTalkToVideo,
]
async def comfy_entrypoint() -> WanExtension:

View File

@ -324,6 +324,7 @@ class GenerateTracks(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="GenerateTracks",
search_aliases=["motion paths", "camera movement", "trajectory"],
category="conditioning/video_models",
inputs=[
io.Int.Input("width", default=832, min=16, max=4096, step=16),

View File

@ -5,6 +5,7 @@ MAX_RESOLUTION = nodes.MAX_RESOLUTION
class WebcamCapture(nodes.LoadImage):
SEARCH_ALIASES = ["camera input", "live capture", "camera feed", "snapshot"]
@classmethod
def INPUT_TYPES(s):
return {

View File

@ -11,7 +11,7 @@ import logging
default_preview_method = args.preview_method
MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
def preview_to_image(latent_image, do_scale=True):
if do_scale:

View File

@ -70,6 +70,7 @@ class CLIPTextEncode(ComfyNodeABC):
CATEGORY = "conditioning"
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"]
def encode(self, clip, text):
if clip is None:
@ -86,11 +87,14 @@ class ConditioningCombine:
FUNCTION = "combine"
CATEGORY = "conditioning"
SEARCH_ALIASES = ["combine", "merge conditioning", "combine prompts", "merge prompts", "mix prompts", "add prompt"]
def combine(self, conditioning_1, conditioning_2):
return (conditioning_1 + conditioning_2, )
class ConditioningAverage :
SEARCH_ALIASES = ["blend prompts", "interpolate conditioning", "mix prompts", "style fusion", "weighted blend"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
@ -157,6 +161,8 @@ class ConditioningConcat:
return (out, )
class ConditioningSetArea:
SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
@ -215,6 +221,8 @@ class ConditioningSetAreaStrength:
class ConditioningSetMask:
SEARCH_ALIASES = ["masked prompt", "regional inpaint conditioning", "mask conditioning"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
@ -240,6 +248,8 @@ class ConditioningSetMask:
return (c, )
class ConditioningZeroOut:
SEARCH_ALIASES = ["null conditioning", "clear conditioning"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", )}}
@ -294,6 +304,7 @@ class VAEDecode:
CATEGORY = "latent"
DESCRIPTION = "Decodes latent images back into pixel space images."
SEARCH_ALIASES = ["decode", "decode latent", "latent to image", "render latent"]
def decode(self, vae, samples):
latent = samples["samples"]
@ -346,6 +357,7 @@ class VAEEncode:
FUNCTION = "encode"
CATEGORY = "latent"
SEARCH_ALIASES = ["encode", "encode image", "image to latent"]
def encode(self, vae, pixels):
t = vae.encode(pixels)
@ -463,6 +475,8 @@ class InpaintModelConditioning:
class SaveLatent:
SEARCH_ALIASES = ["export latent"]
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -514,6 +528,8 @@ class SaveLatent:
class LoadLatent:
SEARCH_ALIASES = ["import latent", "open latent"]
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
@ -550,6 +566,8 @@ class LoadLatent:
class CheckpointLoader:
SEARCH_ALIASES = ["load model", "model loader"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
@ -581,6 +599,7 @@ class CheckpointLoaderSimple:
CATEGORY = "loaders"
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"]
def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
@ -588,6 +607,8 @@ class CheckpointLoaderSimple:
return out[:3]
class DiffusersLoader:
SEARCH_ALIASES = ["load diffusers model"]
@classmethod
def INPUT_TYPES(cls):
paths = []
@ -667,6 +688,7 @@ class LoraLoader:
CATEGORY = "loaders"
DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together."
SEARCH_ALIASES = ["lora", "load lora", "apply lora", "lora loader", "lora model"]
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
@ -701,7 +723,7 @@ class LoraLoaderModelOnly(LoraLoader):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
@staticmethod
def vae_list(s):
@ -814,6 +836,7 @@ class ControlNetLoader:
FUNCTION = "load_controlnet"
CATEGORY = "loaders"
SEARCH_ALIASES = ["controlnet", "control net", "cn", "load controlnet", "controlnet loader"]
def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
@ -890,6 +913,7 @@ class ControlNetApplyAdvanced:
FUNCTION = "apply_controlnet"
CATEGORY = "conditioning/controlnet"
SEARCH_ALIASES = ["controlnet", "apply controlnet", "use controlnet", "control net"]
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
if strength == 0:
@ -1055,6 +1079,8 @@ class StyleModelLoader:
class StyleModelApply:
SEARCH_ALIASES = ["style transfer"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
@ -1200,13 +1226,16 @@ class EmptyLatentImage:
CATEGORY = "latent"
DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling."
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
return ({"samples": latent, "downscale_ratio_spacial": 8}, )
class LatentFromBatch:
SEARCH_ALIASES = ["select from batch", "pick latent", "batch subset"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -1239,6 +1268,8 @@ class LatentFromBatch:
return (s,)
class RepeatLatentBatch:
SEARCH_ALIASES = ["duplicate latent", "clone latent"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -1265,6 +1296,8 @@ class RepeatLatentBatch:
return (s,)
class LatentUpscale:
SEARCH_ALIASES = ["enlarge latent", "resize latent"]
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
crop_methods = ["disabled", "center"]
@ -1299,6 +1332,8 @@ class LatentUpscale:
return (s,)
class LatentUpscaleBy:
SEARCH_ALIASES = ["enlarge latent", "resize latent", "scale latent"]
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
@classmethod
@ -1342,6 +1377,8 @@ class LatentRotate:
return (s,)
class LatentFlip:
SEARCH_ALIASES = ["mirror latent"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -1362,6 +1399,8 @@ class LatentFlip:
return (s,)
class LatentComposite:
SEARCH_ALIASES = ["overlay latent", "layer latent", "paste latent"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples_to": ("LATENT",),
@ -1404,6 +1443,8 @@ class LatentComposite:
return (samples_out,)
class LatentBlend:
SEARCH_ALIASES = ["mix latents", "interpolate latents"]
@classmethod
def INPUT_TYPES(s):
return {"required": {
@ -1445,6 +1486,8 @@ class LatentBlend:
raise ValueError(f"Unsupported blend mode: {mode}")
class LatentCrop:
SEARCH_ALIASES = ["trim latent", "cut latent"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -1495,7 +1538,7 @@ class SetLatentNoiseMask:
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
latent_image = latent["samples"]
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
@ -1513,6 +1556,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = samples
return (out, )
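A simplified sketch of the downscale_ratio_spacial hand-off the hunks above add: empty-latent nodes attach the spatial ratio to the latent dict, common_ksampler forwards it to fix_empty_latent_channels and then strips it before returning. Everything below except the dict key is a stand-in.

import torch

def make_empty_latent(width: int, height: int, batch_size: int = 1) -> dict:
    latent = torch.zeros([batch_size, 4, height // 8, width // 8])
    return {"samples": latent, "downscale_ratio_spacial": 8}

def toy_ksampler(latent: dict) -> dict:
    ratio = latent.get("downscale_ratio_spacial", None)   # forwarded to fix_empty_latent_channels in the real code
    samples = latent["samples"]                           # denoising would happen here
    out = latent.copy()
    out.pop("downscale_ratio_spacial", None)              # the hint is consumed, not passed downstream
    out["samples"] = samples
    return out

print(toy_ksampler(make_empty_latent(512, 512)).keys())   # dict_keys(['samples'])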
@ -1540,6 +1584,7 @@ class KSampler:
CATEGORY = "sampling"
DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image."
SEARCH_ALIASES = ["sampler", "sample", "generate", "denoise", "diffuse", "txt2img", "img2img"]
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
@ -1604,6 +1649,7 @@ class SaveImage:
CATEGORY = "image"
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"]
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
@ -1611,6 +1657,7 @@ class SaveImage:
results = list()
for (batch_number, image) in enumerate(images):
i = 255. * image.cpu().numpy()
i = np.nan_to_num(i, nan=0.0, posinf=255.0, neginf=0.0)
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None
if not args.disable_metadata:
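The nan_to_num line added above matters because np.clip leaves NaN values in place and casting NaN to uint8 is undefined. A quick standalone check:

import numpy as np

i = np.array([float("nan"), float("inf"), -5.0, 300.0, 128.0])
print(np.clip(i, 0, 255))                                      # NaN survives clipping
clean = np.nan_to_num(i, nan=0.0, posinf=255.0, neginf=0.0)
print(np.clip(clean, 0, 255).astype(np.uint8))                 # [  0 255   0 255 128]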
@ -1640,6 +1687,8 @@ class PreviewImage(SaveImage):
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 1
SEARCH_ALIASES = ["preview", "preview image", "show image", "view image", "display image", "image viewer"]
@classmethod
def INPUT_TYPES(s):
return {"required":
@ -1658,6 +1707,7 @@ class LoadImage:
}
CATEGORY = "image"
SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"]
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
@ -1725,6 +1775,8 @@ class LoadImage:
return True
class LoadImageMask:
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
_color_channels = ["alpha", "red", "green", "blue"]
@classmethod
def INPUT_TYPES(s):
@ -1775,6 +1827,8 @@ class LoadImageMask:
class LoadImageOutput(LoadImage):
SEARCH_ALIASES = ["output image", "previous generation"]
@classmethod
def INPUT_TYPES(s):
return {
@ -1810,6 +1864,7 @@ class ImageScale:
FUNCTION = "upscale"
CATEGORY = "image/upscaling"
SEARCH_ALIASES = ["resize", "resize image", "scale image", "image resize", "zoom", "zoom in", "change size"]
def upscale(self, image, upscale_method, width, height, crop):
if width == 0 and height == 0:
@ -1847,6 +1902,7 @@ class ImageScaleBy:
return (s,)
class ImageInvert:
SEARCH_ALIASES = ["reverse colors"]
@classmethod
def INPUT_TYPES(s):
@ -1862,6 +1918,7 @@ class ImageInvert:
return (s,)
class ImageBatch:
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
@classmethod
def INPUT_TYPES(s):
@ -1907,6 +1964,7 @@ class EmptyImage:
return (torch.cat((r, g, b), dim=-1), )
class ImagePadForOutpaint:
SEARCH_ALIASES = ["extend canvas", "expand image"]
@classmethod
def INPUT_TYPES(s):

View File

@ -682,6 +682,8 @@ class PromptServer():
if hasattr(obj_class, 'API_NODE'):
info['api_node'] = obj_class.API_NODE
info['search_aliases'] = getattr(obj_class, 'SEARCH_ALIASES', [])
return info
@routes.get("/object_info")
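Taken together, the PR adds aliases in two styles: legacy classes declare a SEARCH_ALIASES list while v3 schemas pass search_aliases=[...] to io.Schema, and the server hunk above exposes whatever it finds via getattr. A minimal sketch of the legacy path (the node class and helper below are illustrative, not ComfyUI internals):

class LegacyExampleNode:
    SEARCH_ALIASES = ["example alias", "another name"]   # picked up by getattr(obj_class, 'SEARCH_ALIASES', [])

def object_info_entry(obj_class) -> dict:
    info = {"name": obj_class.__name__}
    info["search_aliases"] = getattr(obj_class, "SEARCH_ALIASES", [])
    return info

print(object_info_entry(LegacyExampleNode))
# {'name': 'LegacyExampleNode', 'search_aliases': ['example alias', 'another name']}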