Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2025-08-26 13:59:45 -07:00
commit 443bb45eaf
25 changed files with 986 additions and 16 deletions

View File

@ -38,7 +38,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
@ -50,7 +49,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- Audio Models

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.51"
__version__ = "0.3.52"

View File

@ -0,0 +1,42 @@
from .wav2vec2 import Wav2Vec2Model
import comfy.model_management
import comfy.ops
import comfy.utils
import logging
import torchaudio
class AudioEncoderModel():
def __init__(self, config):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()
def encode_audio(self, audio, sample_rate):
comfy.model_management.load_model_gpu(self.patcher)
audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
out, all_layers = self.model(audio.to(self.load_device))
outputs = {}
outputs["encoded_audio"] = out
outputs["encoded_audio_all_layers"] = all_layers
return outputs
def load_audio_encoder_from_sd(sd, prefix=""):
audio_encoder = AudioEncoderModel(None)
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
m, u = audio_encoder.load_sd(sd)
if len(m) > 0:
logging.warning("missing audio encoder: {}".format(m))
return audio_encoder

View File

@ -0,0 +1,207 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention_masked
class LayerNormConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
super().__init__()
self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
self.layer_norm = operations.LayerNorm(out_channels, elementwise_affine=True, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))
class ConvFeatureEncoder(nn.Module):
def __init__(self, conv_dim, dtype=None, device=None, operations=None):
super().__init__()
self.conv_layers = nn.ModuleList([
LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
])
def forward(self, x):
x = x.unsqueeze(1)
for conv in self.conv_layers:
x = conv(x)
return x.transpose(1, 2)
class FeatureProjection(nn.Module):
def __init__(self, conv_dim, embed_dim, dtype=None, device=None, operations=None):
super().__init__()
self.layer_norm = operations.LayerNorm(conv_dim, eps=1e-05, device=device, dtype=dtype)
self.projection = operations.Linear(conv_dim, embed_dim, device=device, dtype=dtype)
def forward(self, x):
x = self.layer_norm(x)
x = self.projection(x)
return x
class PositionalConvEmbedding(nn.Module):
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
super().__init__()
self.conv = nn.Conv1d(
embed_dim,
embed_dim,
kernel_size=kernel_size,
padding=kernel_size // 2,
groups=groups,
)
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
self.activation = nn.GELU()
def forward(self, x):
x = x.transpose(1, 2)
x = self.conv(x)[:, :, :-1]
x = self.activation(x)
x = x.transpose(1, 2)
return x
class TransformerEncoder(nn.Module):
def __init__(
self,
embed_dim=768,
num_heads=12,
num_layers=12,
mlp_ratio=4.0,
dtype=None, device=None, operations=None
):
super().__init__()
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
self.layers = nn.ModuleList([
TransformerEncoderLayer(
embed_dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_layers)
])
self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
def forward(self, x, mask=None):
x = x + self.pos_conv_embed(x)
all_x = ()
for layer in self.layers:
all_x += (x,)
x = layer(x, mask)
x = self.layer_norm(x)
all_x += (x,)
return x, all_x
class Attention(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
def forward(self, x, mask=None):
assert (mask is None) # TODO?
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = optimized_attention_masked(q, k, v, self.num_heads)
return self.out_proj(out)
class FeedForward(nn.Module):
def __init__(self, embed_dim, mlp_ratio, dtype=None, device=None, operations=None):
super().__init__()
self.intermediate_dense = operations.Linear(embed_dim, int(embed_dim * mlp_ratio), device=device, dtype=dtype)
self.output_dense = operations.Linear(int(embed_dim * mlp_ratio), embed_dim, device=device, dtype=dtype)
def forward(self, x):
x = self.intermediate_dense(x)
x = torch.nn.functional.gelu(x)
x = self.output_dense(x)
return x
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
embed_dim=768,
num_heads=12,
mlp_ratio=4.0,
dtype=None, device=None, operations=None
):
super().__init__()
self.attention = Attention(embed_dim, num_heads, device=device, dtype=dtype, operations=operations)
self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
def forward(self, x, mask=None):
residual = x
x = self.layer_norm(x)
x = self.attention(x, mask=mask)
x = residual + x
x = x + self.feed_forward(self.final_layer_norm(x))
return x
class Wav2Vec2Model(nn.Module):
"""Complete Wav2Vec 2.0 model."""
def __init__(
self,
embed_dim=1024,
final_dim=256,
num_heads=16,
num_layers=24,
dtype=None, device=None, operations=None
):
super().__init__()
conv_dim = 512
self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)
self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
self.encoder = TransformerEncoder(
embed_dim=embed_dim,
num_heads=num_heads,
num_layers=num_layers,
device=device, dtype=dtype, operations=operations
)
def forward(self, x, mask_time_indices=None, return_dict=False):
x = torch.mean(x, dim=1)
x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
features = self.feature_extractor(x)
features = self.feature_projection(features)
batch_size, seq_len, _ = features.shape
x, all_x = self.encoder(features)
return x, all_x

View File

@ -110,6 +110,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
ModelPaths(["classifiers"], supported_extensions=set()),
ModelPaths(["huggingface"], supported_extensions=set()),
ModelPaths(["model_patches"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["audio_encoders"], supported_extensions=set(supported_pt_extensions)),
hf_cache_paths,
hf_xet,
]

View File

@ -19,6 +19,7 @@ import torch
from torch import nn
from ... import model_management
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from ..lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
@ -343,7 +344,28 @@ class ACEStepTransformer2DModel(nn.Module):
output = self.final_layer(hidden_states, embedded_timestep, output_length)
return output
def forward(
def forward(self,
x,
timestep,
attention_mask=None,
context: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.LongTensor] = None,
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
lyrics_strength=1.0,
**kwargs
):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timestep, attention_mask, context, text_attention_mask, speaker_embeds, lyric_token_idx, lyric_mask, block_controlnet_hidden_states,
controlnet_scale, lyrics_strength, **kwargs)
def _forward(
self,
x,
timestep,

View File

@ -10,6 +10,7 @@ import torch.nn.functional as F
from ..modules.attention import optimized_attention
from ... import ops
from .. import common_dit
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@ -436,6 +437,13 @@ class MMDiT(nn.Module):
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, transformer_options, **kwargs)
def _forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE
b, c, h, w = x.shape

View File

@ -6,6 +6,7 @@ import torch
from torch import Tensor, nn
from einops import rearrange, repeat
from ..common_dit import pad_to_patch_size
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from ..flux.layers import EmbedND, timestep_embedding
@ -250,6 +251,13 @@ class Chroma(nn.Module):
return img
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, guidance, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
x = pad_to_patch_size(x, (self.patch_size, self.patch_size))

View File

@ -27,6 +27,8 @@ from torchvision import transforms
from enum import Enum
import logging
import comfy.patcher_extension
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
@ -435,6 +437,42 @@ class GeneralDIT(nn.Module):
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x,
timesteps,
context,
attention_mask,
fps,
image_size,
padding_mask,
scalar_feature,
data_type,
latent_condition,
latent_condition_sigma,
condition_video_augment_sigma,
**kwargs)
def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
# crossattn_emb: torch.Tensor,
# crossattn_mask: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
image_size: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
scalar_feature: Optional[torch.Tensor] = None,
data_type: Optional[DataType] = DataType.VIDEO,
latent_condition: Optional[torch.Tensor] = None,
latent_condition_sigma: Optional[torch.Tensor] = None,
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Args:

View File

@ -11,6 +11,7 @@ import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from ..modules.attention import optimized_attention
def apply_rotary_pos_emb(
@ -805,7 +806,21 @@ class MiniTrainDIT(nn.Module):
)
return x_B_C_Tt_Hp_Wp
def forward(
def forward(self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, fps, padding_mask, **kwargs)
def _forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,

View File

@ -4,6 +4,7 @@ import torch
from dataclasses import dataclass
from einops import rearrange, repeat
from torch import Tensor, nn
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from .layers import (
DoubleStreamBlock,
@ -41,6 +42,7 @@ class Flux(nn.Module):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
# todo: should this be here?
self.device = device
self.dtype = dtype
params = FluxParams(**kwargs)
@ -215,6 +217,13 @@ class Flux(nn.Module):
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, ref_latents, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
bs, c, h_orig, w_orig = x.shape
patch_size = self.patch_size

View File

@ -14,6 +14,7 @@ from ..flux.layers import LastLayer
from ..modules.attention import optimized_attention
from ...model_management import cast_to
from ..common_dit import pad_to_patch_size
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
@ -692,7 +693,23 @@ class HiDreamImageTransformer2DModel(nn.Module):
raise NotImplementedError
return x, x_masks, img_sizes
def forward(
def forward(self,
x: torch.Tensor,
t: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
encoder_hidden_states_llama3=None,
image_cond=None,
control = None,
transformer_options = {},
):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, t, y, context, encoder_hidden_states_llama3, image_cond, control, transformer_options)
def _forward(
self,
x: torch.Tensor,
t: torch.Tensor,

View File

@ -1,6 +1,7 @@
import torch
from torch import nn
from ..flux.layers import DoubleStreamBlock, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
class Hunyuan3Dv2(nn.Module):
@ -61,6 +62,13 @@ class Hunyuan3Dv2(nn.Module):
self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, guidance, transformer_options, **kwargs)
def _forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
x = x.movedim(-1, -2)
timestep = 1.0 - timestep
txt = context

View File

@ -4,6 +4,7 @@ import torch
from ..modules.attention import optimized_attention
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from dataclasses import dataclass
from einops import repeat
@ -339,6 +340,13 @@ class HunyuanVideo(nn.Module):
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)

View File

@ -8,6 +8,7 @@ from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from ..modules.attention import optimized_attention, optimized_attention_masked
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
def get_timestep_embedding(
@ -421,6 +422,13 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
orig_shape = list(x.shape)

View File

@ -11,6 +11,7 @@ from ..common_dit import pad_to_patch_size
from ..modules.diffusionmodules.mmdit import TimestepEmbedder
from ..modules.attention import optimized_attention_masked
from ..flux.layers import EmbedND
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
def modulate(x, scale):
@ -590,8 +591,15 @@ class NextDiT(nn.Module):
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
# def forward(self, x, t, cap_feats, cap_mask):
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask

View File

@ -9,7 +9,7 @@ from ..common_dit import pad_to_patch_size
from ..flux.layers import EmbedND
from ..lightricks.model import TimestepEmbedding, Timesteps
from ..modules.attention import optimized_attention_no_sage_masked as optimized_attention_masked
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@ -355,7 +355,14 @@ class QwenImageTransformer2DModel(nn.Module):
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
def forward(
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
def _forward(
self,
x,
timesteps,

View File

@ -11,6 +11,7 @@ from ..flux.layers import EmbedND
from ..flux.math import apply_rope
from ..common_dit import pad_to_patch_size
from ...model_management import cast_to
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
def sinusoidal_embedding_1d(dim, position):
@ -573,6 +574,13 @@ class WanModel(torch.nn.Module):
return x
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, clip_fea, time_dim_concat, transformer_options, **kwargs)
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
x = pad_to_patch_size(x, self.patch_size)

View File

@ -50,6 +50,7 @@ class WrappersMP:
OUTER_SAMPLE = "outer_sample"
PREPARE_SAMPLING = "prepare_sampling"
SAMPLER_SAMPLE = "sampler_sample"
PREDICT_NOISE = "predict_noise"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"
DIFFUSION_MODEL = "diffusion_model"

View File

@ -24,6 +24,7 @@ from .model_management_types import ModelOptions
from .model_patcher import ModelPatcher
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
from .context_windows import ContextHandlerABC
from .utils import common_upscale
logger = logging.getLogger(__name__)
@ -69,7 +70,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert (mask.shape[1:] == x_in.shape[2:])
# assert (mask.shape[1:] == x_in.shape[2:])
mask = mask[:input_x.shape[0]]
if area is not None:
@ -77,7 +78,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
mask = mask * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1,) * (mask.ndim - 1))
else:
mask = torch.ones_like(input_x)
mult = mask * strength
@ -586,7 +587,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
if len(mask.shape) == len(dims):
mask = mask.unsqueeze(0)
if mask.shape[1:] != dims:
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
if mask.ndim < 4:
mask = common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
else:
mask = common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
if modified.get("set_area_to_bounds", False): # TODO: handle dim != 2
bounds = torch.max(torch.abs(mask), dim=0).values.unsqueeze(0)
@ -991,7 +995,14 @@ class CFGGuider:
self.original_conds[k] = sampler_helpers.convert_cond(conds[k])
def __call__(self, *args, **kwargs):
return self.predict_noise(*args, **kwargs)
return self.outer_predict_noise(*args, **kwargs)
def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self.predict_noise,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, self.model_options, is_model_options=True)
).execute(x, timestep, model_options, seed)
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)

View File

@ -730,6 +730,14 @@ class AnyType(ComfyTypeIO):
class MODEL_PATCH(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO_ENCODER")
class AudioEncoder(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO_ENCODER_OUTPUT")
class AudioEncoderOutput(ComfyTypeIO):
Type = Any
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@ -1586,6 +1594,7 @@ class _IO:
Model = Model
ClipVision = ClipVision
ClipVisionOutput = ClipVisionOutput
AudioEncoderOutput = AudioEncoderOutput
StyleModel = StyleModel
Gligen = Gligen
UpscaleModel = UpscaleModel

View File

@ -0,0 +1,44 @@
import folder_paths
import comfy.audio_encoders.audio_encoders
import comfy.utils
class AudioEncoderLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio_encoder_name": (folder_paths.get_filename_list("audio_encoders"), ),
}}
RETURN_TYPES = ("AUDIO_ENCODER",)
FUNCTION = "load_model"
CATEGORY = "loaders"
def load_model(self, audio_encoder_name):
audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
if audio_encoder is None:
raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
return (audio_encoder,)
class AudioEncoderEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio_encoder": ("AUDIO_ENCODER",),
"audio": ("AUDIO",),
}}
RETURN_TYPES = ("AUDIO_ENCODER_OUTPUT",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, audio_encoder, audio):
output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
return (output,)
NODE_CLASS_MAPPINGS = {
"AudioEncoderLoader": AudioEncoderLoader,
"AudioEncoderEncode": AudioEncoderEncode,
}

View File

@ -0,0 +1,493 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from comfy_api.latest import io, ComfyExtension
import comfy.patcher_extension
import logging
import torch
import comfy.model_patcher
if TYPE_CHECKING:
from uuid import UUID
def easycache_forward_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
transformer_options: dict[str] = args[-1]
if not isinstance(transformer_options, dict):
transformer_options = kwargs.get("transformer_options")
if not transformer_options:
transformer_options = args[-2]
easycache: EasyCacheHolder = transformer_options["easycache"]
sigmas = transformer_options["sigmas"]
uuids = transformer_options["uuids"]
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
return executor(*args, **kwargs)
# prepare next x_prev
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(sigmas)
if do_easycache:
easycache.check_metadata(x)
# if first cond marked this step for skipping, skip it and use appropriate cached values
if easycache.skip_current_step:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
return easycache.apply_cache_diff(x, uuids)
if easycache.initial_step:
easycache.first_cond_uuid = uuids[0]
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
easycache.initial_step = False
if has_first_cond_uuid:
if easycache.has_x_prev_subsampled():
input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values
easycache.skip_current_step = True
return easycache.apply_cache_diff(x, uuids)
else:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
output: torch.Tensor = executor(*args, **kwargs)
if has_first_cond_uuid and easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
output_change_rate = output_change / easycache.output_prev_norm
easycache.output_change_rates.append(output_change_rate.item())
if easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
if easycache.verbose:
logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
if input_change is not None:
easycache.relative_transformation_rate = output_change / input_change
if easycache.verbose:
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
# TODO: allow cache_diff to be offloaded
easycache.update_cache_diff(output, next_x_prev, uuids)
if has_first_cond_uuid:
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
return output
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
# get values from args
x: torch.Tensor = args[0]
timestep: float = args[1]
model_options: dict[str] = args[2]
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
# prepare next x_prev
next_x_prev = x
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
if do_easycache:
easycache.check_metadata(x)
if easycache.has_x_prev_subsampled():
if easycache.has_x_prev_subsampled():
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.verbose:
logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values
easycache.skip_current_step = True
return easycache.apply_cache_diff(x)
else:
if easycache.verbose:
logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
easycache.cumulative_change_rate = 0.0
output: torch.Tensor = executor(*args, **kwargs)
if easycache.has_output_prev_norm():
output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
if easycache.verbose:
output_change_rate = output_change / easycache.output_prev_norm
easycache.output_change_rates.append(output_change_rate.item())
if easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
if easycache.verbose:
logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
if input_change is not None:
easycache.relative_transformation_rate = output_change / input_change
if easycache.verbose:
logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
# TODO: allow cache_diff to be offloaded
easycache.update_cache_diff(output, next_x_prev)
easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
easycache.output_prev_subsampled = easycache.subsample(output)
easycache.output_prev_norm = output.flatten().abs().mean()
if easycache.verbose:
logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
return output
def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
model_options = args[-1]
easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"]
easycache.skip_current_step = False
# TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset
return executor(*args, **kwargs)
def easycache_sample_wrapper(executor, *args, **kwargs):
"""
This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end.
"""
try:
guider = executor.class_obj
orig_model_options = guider.model_options
guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
# clone and prepare timesteps
guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache']
logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
return executor(*args, **kwargs)
finally:
easycache = guider.model_options['transformer_options']['easycache']
output_change_rates = easycache.output_change_rates
approx_output_change_rates = easycache.approx_output_change_rates
if easycache.verbose:
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
total_steps = len(args[3])-1
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
easycache.reset()
guider.model_options = orig_model_options
class EasyCacheHolder:
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
self.name = "EasyCache"
self.reuse_threshold = reuse_threshold
self.start_percent = start_percent
self.end_percent = end_percent
self.subsample_factor = subsample_factor
self.offload_cache_diff = offload_cache_diff
self.verbose = verbose
# timestep values
self.start_t = 0.0
self.end_t = 0.0
# control values
self.relative_transformation_rate: float = None
self.cumulative_change_rate = 0.0
self.initial_step = True
self.skip_current_step = False
# cache values
self.first_cond_uuid = None
self.x_prev_subsampled: torch.Tensor = None
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
self.output_change_rates = []
self.approx_output_change_rates = []
self.total_steps_skipped = 0
# how to deal with mismatched dims
self.allow_mismatch = True
self.cut_from_start = True
self.state_metadata = None
def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item()
def should_do_easycache(self, timestep: float) -> bool:
return (timestep[0] <= self.start_t).item()
def has_x_prev_subsampled(self) -> bool:
return self.x_prev_subsampled is not None
def has_output_prev_subsampled(self) -> bool:
return self.output_prev_subsampled is not None
def has_output_prev_norm(self) -> bool:
return self.output_prev_norm is not None
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
def prepare_timesteps(self, model_sampling):
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
return self
def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor:
batch_offset = x.shape[0] // len(uuids)
uuid_idx = uuids.index(self.first_cond_uuid)
if self.subsample_factor > 1:
to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor]
if clone:
return to_return.clone()
return to_return
to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...]
if clone:
return to_return.clone()
return to_return
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
if self.first_cond_uuid in uuids:
self.total_steps_skipped += 1
batch_offset = x.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
if not self.allow_mismatch:
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
slicing = []
skip_this_dim = True
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
if skip_this_dim:
skip_this_dim = False
continue
if dim_u != dim_x:
if self.cut_from_start:
slicing.append(slice(dim_x-dim_u, None))
else:
slicing.append(slice(None, dim_u))
else:
slicing.append(slice(None))
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
x = x[slicing]
x += self.uuid_cache_diffs[uuid].to(x.device)
return x
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if output.shape[1:] != x.shape[1:]:
if not self.allow_mismatch:
raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good")
slicing = []
skip_dim = True
for dim_o, dim_x in zip(output.shape, x.shape):
if not skip_dim and dim_o != dim_x:
if self.cut_from_start:
slicing.append(slice(dim_x-dim_o, None))
else:
slicing.append(slice(None, dim_o))
else:
slicing.append(slice(None))
skip_dim = False
x = x[slicing]
diff = output - x
batch_offset = diff.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
return self.first_cond_uuid in uuids
def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape[1:])
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False
def reset(self):
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
self.initial_step = True
self.skip_current_step = False
self.output_change_rates = []
self.first_cond_uuid = None
del self.x_prev_subsampled
self.x_prev_subsampled = None
del self.output_prev_subsampled
self.output_prev_subsampled = None
del self.output_prev_norm
self.output_prev_norm = None
del self.uuid_cache_diffs
self.uuid_cache_diffs = {}
self.total_steps_skipped = 0
self.state_metadata = None
return self
def clone(self):
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
class EasyCacheNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EasyCache",
display_name="EasyCache",
description="Native EasyCache implementation.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add EasyCache to."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with EasyCache."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
model = model.clone()
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
return io.NodeOutput(model)
class LazyCacheHolder:
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
self.name = "LazyCache"
self.reuse_threshold = reuse_threshold
self.start_percent = start_percent
self.end_percent = end_percent
self.subsample_factor = subsample_factor
self.offload_cache_diff = offload_cache_diff
self.verbose = verbose
# timestep values
self.start_t = 0.0
self.end_t = 0.0
# control values
self.relative_transformation_rate: float = None
self.cumulative_change_rate = 0.0
self.initial_step = True
# cache values
self.x_prev_subsampled: torch.Tensor = None
self.output_prev_subsampled: torch.Tensor = None
self.output_prev_norm: torch.Tensor = None
self.cache_diff: torch.Tensor = None
self.output_change_rates = []
self.approx_output_change_rates = []
self.total_steps_skipped = 0
self.state_metadata = None
def has_cache_diff(self) -> bool:
return self.cache_diff is not None
def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item()
def should_do_easycache(self, timestep: float) -> bool:
return (timestep[0] <= self.start_t).item()
def has_x_prev_subsampled(self) -> bool:
return self.x_prev_subsampled is not None
def has_output_prev_subsampled(self) -> bool:
return self.output_prev_subsampled is not None
def has_output_prev_norm(self) -> bool:
return self.output_prev_norm is not None
def has_relative_transformation_rate(self) -> bool:
return self.relative_transformation_rate is not None
def prepare_timesteps(self, model_sampling):
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
return self
def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
if self.subsample_factor > 1:
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
if clone:
return to_return.clone()
return to_return
if clone:
return x.clone()
return x
def apply_cache_diff(self, x: torch.Tensor):
self.total_steps_skipped += 1
return x + self.cache_diff.to(x.device)
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
self.cache_diff = output - x
def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape)
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False
def reset(self):
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
self.initial_step = True
self.output_change_rates = []
self.approx_output_change_rates = []
del self.cache_diff
self.cache_diff = None
del self.x_prev_subsampled
self.x_prev_subsampled = None
del self.output_prev_subsampled
self.output_prev_subsampled = None
del self.output_prev_norm
self.output_prev_norm = None
self.total_steps_skipped = 0
self.state_metadata = None
return self
def clone(self):
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
class LazyCacheNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LazyCache",
display_name="LazyCache",
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
category="advanced/debug/model",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to add LazyCache to."),
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."),
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."),
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
],
outputs=[
io.Model.Output(tooltip="The model with LazyCache."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
model = model.clone()
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
return io.NodeOutput(model)
class EasyCacheExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EasyCacheNode,
LazyCacheNode,
]
def comfy_entrypoint():
return EasyCacheExtension()

View File

@ -1,6 +1,6 @@
[project]
name = "comfyui"
version = "0.3.51"
version = "0.3.52"
description = "An installable version of ComfyUI"
readme = "README.md"
authors = [
@ -18,8 +18,8 @@ classifiers = [
]
dependencies = [
"comfyui-frontend-package>=1.25.9",
"comfyui-workflow-templates>=0.1.65",
"comfyui-frontend-package>=1.25.10",
"comfyui-workflow-templates>=0.1.66",
"comfyui-embedded-docs>=0.2.6",
"torch",
"torchvision",