Merge branch 'master' of github.com:comfyanonymous/ComfyUI

doctorpangloss 2024-07-30 11:11:45 -07:00
commit 34522e0914
55 changed files with 48774 additions and 143 deletions

.ci/update_windows/update.py Executable file

.github/workflows/stable-release.yml vendored Normal file


@ -43,6 +43,7 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.


@ -14,6 +14,7 @@ from ..ldm.modules.attention import SpatialTransformer, optimized_attention
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
from .. import ops
from .control_types import UNION_CONTROLNET_TYPES
from collections import OrderedDict
@ -392,6 +393,18 @@ class ControlNet(nn.Module):
if self.control_add_embedding is not None: #Union Controlnet
control_type = kwargs.get("control_type", [])
if any([c >= self.num_control_type for c in control_type]):
max_type = max(control_type)
max_type_name = {
v: k for k, v in UNION_CONTROLNET_TYPES.items()
}[max_type]
raise ValueError(
f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
f"({self.num_control_type}) supported.\n" +
"Please consider using the ProMax ControlNet Union model.\n" +
"https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
)
emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
if len(control_type) > 0:
if len(hint.shape) < 5:


@ -0,0 +1,10 @@
UNION_CONTROLNET_TYPES = {
"openpose": 0,
"depth": 1,
"hed/pidi/scribble/ted": 2,
"canny/lineart/anime_lineart/mlsd": 3,
"normal": 4,
"segment": 5,
"tile": 6,
"repaint": 7,
}
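A minimal usage sketch for the new mapping (the helper is hypothetical; it assumes the module sits next to the ControlNet model, e.g. comfy.cldm.control_types, as the relative import above suggests):

from comfy.cldm.control_types import UNION_CONTROLNET_TYPES  # assumed path

def control_type_indices(names):
    # e.g. ["openpose", "depth"] -> [0, 1]; indices at or above the loaded model's
    # num_control_type now raise the ValueError added above, which points at ProMax
    return [UNION_CONTROLNET_TYPES[n] for n in names]

print(control_type_indices(["openpose", "depth"]))  # [0, 1]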


@ -6,6 +6,7 @@ import inspect
import logging
import sys
import threading
import time
import traceback
import typing
from typing import List, Optional, Tuple
@ -463,6 +464,11 @@ class PromptExecutor:
self.status_messages = []
def add_message(self, event, data, broadcast: bool):
data = {
**data,
# todo: use a real time library
"timestamp": int(time.time() * 1000),
}
self.status_messages.append((event, data))
if self.server.client_id is not None or broadcast:
self.server.send_sync(event, data, self.server.client_id)


@ -225,8 +225,8 @@ async def main():
def entrypoint():
try:
asyncio.run(main())
except KeyboardInterrupt as keyboard_interrupt:
logging.info(f"Gracefully shutting down due to {keyboard_interrupt}")
except KeyboardInterrupt:
logging.info(f"Gracefully shutting down due to KeyboardInterrupt")
if __name__ == "__main__":


@ -7,10 +7,13 @@ Use this instead of cli_args to import the args:
It will enable command line argument parsing. If this isn't desired, you must author your own implementation of these fixes.
"""
import importlib.util
import logging
import os
import shutil
import sys
import warnings
import ctypes
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
@ -18,7 +21,7 @@ from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SpanExporter
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
from opentelemetry.semconv.resource import ResourceAttributes as ResAttrs
from .. import options
@ -46,6 +49,31 @@ if args.deterministic:
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
def _fix_pytorch_240():
"""Fixes pytorch 2.4.0"""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
lib_folder = os.path.join(folder, "lib")
test_file = os.path.join(lib_folder, "fbgemm.dll")
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
if os.path.exists(dest):
break
with open(test_file, 'rb') as f:
contents = f.read()
# todo: dubious
if b"libomp140.x86_64.dll" not in contents:
break
try:
_ = ctypes.cdll.LoadLibrary(test_file)
except FileNotFoundError:
logging.warning("Detected pytorch version with libomp issue, trying to patch")
try:
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
except Exception as exc_info:
logging.error("While trying to patch a fix for torch 2.4.0, an error occurred, which means this is unlikely to work", exc_info=exc_info)
def _create_tracer():
resource = Resource.create({
ResAttrs.SERVICE_NAME: args.otel_service_name,
@ -79,5 +107,6 @@ def _create_tracer():
return trace.get_tracer(args.otel_service_name)
_fix_pytorch_240()
tracer = _create_tracer()
__all__ = ["args", "tracer"]


@ -0,0 +1,17 @@
from typing import Annotated
from jaxtyping import Float, Shaped
from torch import Tensor
def channels_constraint(n: int):
def constraint(x: Tensor) -> bool:
return x.shape[-1] == n
return constraint
ImageBatch = Float[Tensor, "batch height width channels"]
RGBImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(3)]] | Float[Tensor, "batch height width 3"]
RGBAImageBatch = Annotated[ImageBatch, Shaped[channels_constraint(4)]] | Float[Tensor, "batch height width 4"]
RGBImage = Float[Tensor, "height width 3"]
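A short sketch of how the plain aliases might annotate a helper; the function is hypothetical, and the annotations are only enforced if a runtime checker such as beartype is active:

import torch

def first_frame(images: ImageBatch) -> RGBImage:
    # hypothetical helper: take the first image of the batch, keeping only the RGB channels
    return images[0, :, :, :3]

batch = torch.rand(2, 64, 64, 3)   # "batch height width channels"
print(first_frame(batch).shape)    # torch.Size([64, 64, 3])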


@ -4,8 +4,8 @@ from contextlib import contextmanager
from contextvars import ContextVar
from typing import NamedTuple
from comfy.component_model.executor_types import ExecutorToClientProgress
from comfy.distributed.server_stub import ServerStub
from .component_model.executor_types import ExecutorToClientProgress
from .distributed.server_stub import ServerStub
_current_context = ContextVar("comfyui_execution_context")


@ -0,0 +1,219 @@
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional
from comfy.ldm.modules.attention import optimized_attention
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Args:
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: Optional[torch.Tensor],
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
if xk is not None:
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
else:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
if xk is not None:
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
class CrossAttention(nn.Module):
"""
Use QK Normalization.
"""
def __init__(self,
qdim,
kdim,
num_heads,
qkv_bias=True,
qk_norm=False,
attn_drop=0.0,
proj_drop=0.0,
attn_precision=None,
device=None,
dtype=None,
operations=None,
):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.attn_precision = attn_precision
self.qdim = qdim
self.kdim = kdim
self.num_heads = num_heads
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
self.head_dim = self.qdim // num_heads
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim ** -0.5
self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, y, freqs_cis_img=None):
"""
Parameters
----------
x: torch.Tensor
(batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
y: torch.Tensor
(batch, seqlen2, hidden_dim2)
freqs_cis_img: torch.Tensor
(batch, hidden_dim // 2), RoPE for image
"""
b, s1, c = x.shape # [b, s1, D]
_, s2, c = y.shape # [b, s2, 1024]
q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
k, v = kv.unbind(dim=2) # [b, s, h, d]
q = self.q_norm(q)
k = self.k_norm(k)
# Apply RoPE if needed
if freqs_cis_img is not None:
qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
q = qq
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
v = v.transpose(-2, -3).contiguous()
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
out = self.out_proj(context) # context.reshape - B, L1, -1
out = self.proj_drop(out)
out_tuple = (out,)
return out_tuple
class Attention(nn.Module):
"""
We rename some layer names to align with flash attention
"""
def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
self.attn_precision = attn_precision
self.dim = dim
self.num_heads = num_heads
assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
self.head_dim = self.dim // num_heads
# This assertion is aligned with flash attention
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim ** -0.5
# qkv --> Wqkv
self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, freqs_cis_img=None):
B, N, C = x.shape
qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d]
q, k, v = qkv.unbind(0) # [b, h, s, d]
q = self.q_norm(q) # [b, h, s, d]
k = self.k_norm(k) # [b, h, s, d]
# Apply RoPE if needed
if freqs_cis_img is not None:
qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
assert qq.shape == q.shape and kk.shape == k.shape, \
f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
q, k = qq, kk
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
x = self.out_proj(x)
x = self.proj_drop(x)
out_tuple = (x,)
return out_tuple
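apply_rotary_emb accepts either a complex freqs_cis tensor or a real (cos, sin) pair shaped [S, D]. A minimal sketch of the real-valued path with toy shapes; the frequency construction mirrors get_1d_rotary_pos_embed(..., use_real=True) from posemb_layers.py further below:

import torch

B, S, H, D = 1, 4, 2, 8                    # toy batch, sequence, heads, head_dim
xq = torch.randn(B, S, H, D)
xk = torch.randn(B, S, H, D)

inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))
angles = torch.outer(torch.arange(S).float(), inv_freq)   # [S, D/2]
cos = angles.cos().repeat_interleave(2, dim=1)            # [S, D]
sin = angles.sin().repeat_interleave(2, dim=1)            # [S, D]

xq_rot, xk_rot = apply_rotary_emb(xq, xk, (cos, sin), head_first=False)
print(xq_rot.shape, xk_rot.shape)          # both torch.Size([1, 4, 2, 8])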

comfy/ldm/hydit/models.py Normal file

@ -0,0 +1,404 @@
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from torch.utils import checkpoint
from .attn_layers import Attention, CrossAttention
from .poolers import AttentionPool
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
def calc_rope(x, patch_size, head_size):
th = (x.shape[2] + (patch_size // 2)) // patch_size
tw = (x.shape[3] + (patch_size // 2)) // patch_size
base_size = 512 // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
return rope
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class HunYuanDiTBlock(nn.Module):
"""
A HunYuanDiT block with `add` conditioning.
"""
def __init__(self,
hidden_size,
c_emb_size,
num_heads,
mlp_ratio=4.0,
text_states_dim=1024,
qk_norm=False,
norm_type="layer",
skip=False,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
use_ele_affine = True
if norm_type == "layer":
norm_layer = operations.LayerNorm
elif norm_type == "rms":
norm_layer = RMSNorm
else:
raise ValueError(f"Unknown norm_type: {norm_type}")
# ========================= Self-Attention =========================
self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
# ========================= FFN =========================
self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, dtype=dtype, device=device, operations=operations)
# ========================= Add =========================
# Simply use add like SDXL.
self.default_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, hidden_size, bias=True, dtype=dtype, device=device)
)
# ========================= Cross-Attention =========================
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
# ========================= Skip Connection =========================
if skip:
self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, dtype=dtype, device=device)
else:
self.skip_linear = None
self.gradient_checkpointing = False
def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
# Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([x, skip], dim=-1)
cat = self.skip_norm(cat)
x = self.skip_linear(cat)
# Self-Attention
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
attn_inputs = (
self.norm1(x) + shift_msa, freq_cis_img,
)
x = x + self.attn1(*attn_inputs)[0]
# Cross-Attention
cross_inputs = (
self.norm3(x), text_states, freq_cis_img
)
x = x + self.attn2(*cross_inputs)[0]
# FFN Layer
mlp_inputs = self.norm2(x)
x = x + self.mlp(mlp_inputs)
return x
def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
if self.gradient_checkpointing and self.training:
return checkpoint.checkpoint(self._forward, x, c, text_states, freq_cis_img, skip)
return self._forward(x, c, text_states, freq_cis_img, skip)
class FinalLayer(nn.Module):
"""
The final layer of HunYuanDiT.
"""
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class HunYuanDiT(nn.Module):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
input_size: tuple
The size of the input image.
patch_size: int
The size of the patch.
in_channels: int
The number of input channels.
hidden_size: int
The hidden size of the transformer backbone.
depth: int
The number of transformer blocks.
num_heads: int
The number of attention heads.
mlp_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
log_fn: callable
The logging function.
"""
#@register_to_config
def __init__(self,
input_size: tuple = 32,
patch_size: int = 2,
in_channels: int = 4,
hidden_size: int = 1152,
depth: int = 28,
num_heads: int = 16,
mlp_ratio: float = 4.0,
text_states_dim = 1024,
text_states_dim_t5 = 2048,
text_len = 77,
text_len_t5 = 256,
qk_norm = True,# See http://arxiv.org/abs/2302.05442 for details.
size_cond = False,
use_style_cond = False,
learn_sigma = True,
norm = "layer",
log_fn: callable = print,
attn_precision=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.log_fn = log_fn
self.depth = depth
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.text_states_dim = text_states_dim
self.text_states_dim_t5 = text_states_dim_t5
self.text_len = text_len
self.text_len_t5 = text_len_t5
self.size_cond = size_cond
self.use_style_cond = use_style_cond
self.norm = norm
self.dtype = dtype
#import pdb
#pdb.set_trace()
self.mlp_t5 = nn.Sequential(
operations.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True, dtype=dtype, device=device),
)
# learnable replace
self.text_embedding_padding = nn.Parameter(
torch.empty(self.text_len + self.text_len_t5, self.text_states_dim, dtype=dtype, device=device))
# Attention pooling
pooler_out_dim = 1024
self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=pooler_out_dim, dtype=dtype, device=device, operations=operations)
# Dimension of the extra input vectors
self.extra_in_dim = pooler_out_dim
if self.size_cond:
# Image size and crop size conditions
self.extra_in_dim += 6 * 256
if self.use_style_cond:
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(1, hidden_size, dtype=dtype, device=device)
self.extra_in_dim += hidden_size
# Text embedding for `add`
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, dtype=dtype, device=device, operations=operations)
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device, operations=operations)
self.extra_embedder = nn.Sequential(
operations.Linear(self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
)
# Image embedding
num_patches = self.x_embedder.num_patches
# HunYuanDiT Blocks
self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=self.text_states_dim,
qk_norm=qk_norm,
norm_type=self.norm,
skip=layer > depth // 2,
attn_precision=attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
for layer in range(depth)
])
self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
self.unpatchify_channels = self.out_channels
def forward(self,
x,
t,
context,#encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
return_dict=False,
control=None,
transformer_options=None,
):
"""
Forward pass of the encoder.
Parameters
----------
x: torch.Tensor
(B, D, H, W)
t: torch.Tensor
(B)
encoder_hidden_states: torch.Tensor
CLIP text embedding, (B, L_clip, D)
text_embedding_mask: torch.Tensor
CLIP text embedding mask, (B, L_clip)
encoder_hidden_states_t5: torch.Tensor
T5 text embedding, (B, L_t5, D)
text_embedding_mask_t5: torch.Tensor
T5 text embedding mask, (B, L_t5)
image_meta_size: torch.Tensor
(B, 6)
style: torch.Tensor
(B)
cos_cis_img: torch.Tensor
sin_cis_img: torch.Tensor
return_dict: bool
Whether to return a dictionary.
"""
#import pdb
#pdb.set_trace()
encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
padding = self.text_embedding_padding.to(text_states)
text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])
text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,205,1024
# clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
_, _, oh, ow = x.shape
th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size
# Get image RoPE embedding according to `reso`lution.
freqs_cis_img = calc_rope(x, self.patch_size, self.hidden_size // self.num_heads) #(cos_cis_img, sin_cis_img)
# ========================= Build time and image embedding =========================
t = self.t_embedder(t, dtype=x.dtype)
x = self.x_embedder(x)
# ========================= Concatenate all extra vectors =========================
# Build text tokens with pooling
extra_vec = self.pooler(encoder_hidden_states_t5)
# Build image meta size tokens if applicable
if self.size_cond:
image_meta_size = timestep_embedding(image_meta_size.view(-1), 256).to(x.dtype) # [B * 6, 256]
image_meta_size = image_meta_size.view(-1, 6 * 256)
extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
# Build style tokens
if self.use_style_cond:
if style is None:
style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
style_embedding = self.style_embedder(style)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
controls = None
# ========================= Forward pass through HunYuanDiT blocks =========================
skips = []
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
if controls is not None:
skip = skips.pop() + controls.pop()
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)
if controls is not None and len(controls) != 0:
raise ValueError("The number of controls is not equal to the number of skip connections.")
# ========================= Final layer =========================
x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)
if return_dict:
return {'x': x}
if self.learn_sigma:
return x[:,:self.out_channels // 2,:oh,:ow]
return x[:,:,:oh,:ow]
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.unpatchify_channels
p = self.x_embedder.patch_size[0]
# h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
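A worked example of calc_rope for DiT-g/2 (hidden_size 1408, 16 heads), assuming a 1024x1024 image whose VAE latent is 128x128; the values only illustrate the output shapes:

import torch

latent = torch.zeros(1, 4, 128, 128)                     # 1024 / 8 = 128 latent pixels per side
cos, sin = calc_rope(latent, patch_size=2, head_size=1408 // 16)
# 128 latent pixels / patch_size 2 = 64 patches per side -> 64 * 64 positions
print(cos.shape, sin.shape)                               # torch.Size([4096, 88]) for both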


@ -0,0 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention #TODO
class AttentionPool(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
self.num_heads = num_heads
self.embed_dim = embed_dim
def forward(self, x):
x = x[:,:self.positional_embedding.shape[0] - 1]
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device) # (L+1)NC
q = self.q_proj(x[:1])
k = self.k_proj(x)
v = self.v_proj(x)
batch_size = q.shape[1]
head_dim = self.embed_dim // self.num_heads
q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)
attn_output = self.c_proj(attn_output)
return attn_output.squeeze(0)
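A shape-only sketch of the pooler as it is used for the T5 states in models.py, assuming torch.nn can stand in for the operations namespace (weights are uninitialized, so only the shapes are meaningful):

import torch

pool = AttentionPool(spacial_dim=256, embed_dim=2048, num_heads=8,
                     output_dim=1024, operations=torch.nn)  # torch.nn assumed as ops
t5_states = torch.randn(2, 256, 2048)   # (batch, text_len_t5, text_states_dim_t5)
print(pool(t5_states).shape)            # torch.Size([2, 1024])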


@ -0,0 +1,224 @@
import torch
import numpy as np
from typing import Union
def _to_tuple(x):
if isinstance(x, int):
return x, x
else:
return x
def get_fill_resize_and_crop(src, tgt):
th, tw = _to_tuple(tgt)
h, w = _to_tuple(src)
tr = th / tw # base resolution
r = h / w # target resolution
# resize
if r > tr:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h)) # resize the target resolution down based on the base resolution
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def get_meshgrid(start, *args):
if len(args) == 0:
# start is grid_size
num = _to_tuple(start)
start = (0, 0)
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start)
stop = _to_tuple(args[0])
num = (stop[0] - start[0], stop[1] - start[1])
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start)
stop = _to_tuple(args[0])
num = _to_tuple(args[1])
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0) # [2, W, H]
return grid
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid = get_meshgrid(start, *args) # [2, H, w]
# grid_h = np.arange(grid_size, dtype=np.float32)
# grid_w = np.arange(grid_size, dtype=np.float32)
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
# grid = np.stack(grid, axis=0) # [2, W, H]
grid = grid.reshape([2, 1, *grid.shape[1:]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (W,H)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
"""
This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
Parameters
----------
embed_dim: int
embedding dimension size
start: int or tuple of int
If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
If len(args) == 2, start is start, args[0] is stop, args[1] is num.
use_real: bool
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns
-------
pos_embed: torch.Tensor
[HW, D/2]
"""
grid = get_meshgrid(start, *args) # [2, H, w]
grid = grid.reshape([2, 1, *grid.shape[1:]]) # Returns a sampling matrix with the same resolution as the target resolution
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
if isinstance(pos, int):
pos = np.arange(pos)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
def calc_sizes(rope_img, patch_size, th, tw):
if rope_img == 'extend':
# Expansion mode
sub_args = [(th, tw)]
elif rope_img.startswith('base'):
# Based on the specified dimensions, other dimensions are obtained through interpolation.
base_size = int(rope_img[4:]) // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
else:
raise ValueError(f"Unknown rope_img: {rope_img}")
return sub_args
def init_image_posemb(rope_img,
resolutions,
patch_size,
hidden_size,
num_heads,
log_fn,
rope_real=True,
):
freqs_cis_img = {}
for reso in resolutions:
th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
sub_args = calc_sizes(rope_img, patch_size, th, tw)
freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
return freqs_cis_img
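A small worked example of get_fill_resize_and_crop, which maps a target patch grid onto the base training grid before the rotary table is sampled (the numbers are illustrative only):

# 64x64 target patch grid (a 1024x1024 image) mapped onto the 32x32 base grid used at 512px
print(get_fill_resize_and_crop((64, 64), 32))   # ((0, 0), (32, 32))

# a portrait 96x48 grid instead resizes and centers inside the base square
print(get_fill_resize_and_crop((96, 48), 32))   # ((0, 8), (32, 24))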


@ -8,6 +8,7 @@ import torch
import torch.nn as nn
from .. import attention
from einops import rearrange, repeat
from .util import timestep_embedding
def default(x, y):
if x is not None:
@ -231,34 +232,8 @@ class TimestepEmbedder(nn.Module):
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
/ half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
if torch.is_floating_point(t):
embedding = embedding.to(dtype=t.dtype)
return embedding
def forward(self, t, dtype, **kwargs):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb


@ -280,4 +280,10 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
key_map[key_lora] = to
if isinstance(model, model_base.HunyuanDiT):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
return key_map


@ -18,6 +18,7 @@ from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.aura.mmdit import MMDiT as AuraMMDiT
from .ldm.hydit.models import HunYuanDiT
class ModelType(Enum):
EPS = 1
@ -669,3 +670,35 @@ class StableAudio1(BaseModel):
for l in s:
sd["{}{}".format(k, l)] = s[l]
return sd
class HunyuanDiT(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device, unet_model=HunYuanDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDRegular(cross_attn)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['text_embedding_mask'] = conds.CONDRegular(attention_mask)
conditioning_mt5xl = kwargs.get("conditioning_mt5xl", None)
if conditioning_mt5xl is not None:
out['encoder_hidden_states_t5'] = conds.CONDRegular(conditioning_mt5xl)
attention_mask_mt5xl = kwargs.get("attention_mask_mt5xl", None)
if attention_mask_mt5xl is not None:
out['text_embedding_mask_t5'] = conds.CONDRegular(attention_mask_mt5xl)
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)
out['image_meta_size'] = conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out


@ -114,6 +114,19 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["n_layers"] = double_layers + single_layers
return unet_config
if '{}mlp_t5.0.weight'.format(key_prefix) in state_dict_keys: #Hunyuan DiT
unet_config = {}
unet_config["image_model"] = "hydit"
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
unet_config["mlp_ratio"] = 4.3637
if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
unet_config["size_cond"] = True
unet_config["use_style_cond"] = True
unet_config["image_model"] = "hydit1"
return unet_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -260,13 +273,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
return model_config
def unet_prefix_from_state_dict(state_dict):
if "model.model.postprocess_conv.weight" in state_dict: #audio models
unet_key_prefix = "model.model."
elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
unet_key_prefix = "model."
candidates = ["model.diffusion_model.", #ldm/sgm models
"model.model.", #audio models
]
counts = {k: 0 for k in candidates}
for k in state_dict:
for c in candidates:
if k.startswith(c):
counts[c] += 1
break
top = max(counts, key=counts.get)
if counts[top] > 5:
return top
else:
unet_key_prefix = "model.diffusion_model."
return unet_key_prefix
return "model." #aura flow and others
def convert_config(unet_config):
new_config = unet_config.copy()


@ -220,7 +220,11 @@ KNOWN_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"),
HuggingFile("fal/AuraFlow", filename="aura_flow_0.1.safetensors"),
# stable audio, # uses names from https://comfyanonymous.github.io/ComfyUI_examples/audio/
HuggingFile("stabilityai/stable-audio-open-1.0", "model.safetensors", save_with_filename="stable_audio_open_1.0.safetensors")
HuggingFile("stabilityai/stable-audio-open-1.0", "model.safetensors", save_with_filename="stable_audio_open_1.0.safetensors"),
# hunyuandit
HuggingFile("comfyanonymous/hunyuan_dit_comfyui", "hunyuan_dit_1.0.safetensors"),
HuggingFile("comfyanonymous/hunyuan_dit_comfyui", "hunyuan_dit_1.1.safetensors"),
HuggingFile("comfyanonymous/hunyuan_dit_comfyui", "hunyuan_dit_1.2.safetensors"),
], folder_name="checkpoints")
KNOWN_UNCLIP_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([


@ -1,6 +1,6 @@
import hashlib
from comfy.cli_args import args
from .cli_args import args
from PIL import ImageFile, UnidentifiedImageError
@ -25,7 +25,7 @@ def pillow(fn, arg):
finally:
if prev_value is not None:
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
return x
return x
def hasher():
hashfuncs = {


@ -1,5 +1,5 @@
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]


@ -5,6 +5,8 @@ import collections
from . import model_management
import math
import logging
import scipy
import numpy
from . import sampler_helpers
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
@ -314,13 +316,18 @@ def simple_scheduler(model_sampling, steps):
def ddim_scheduler(model_sampling, steps):
s = model_sampling
sigs = []
ss = max(len(s.sigmas) // steps, 1)
x = 1
if math.isclose(float(s.sigmas[x]), 0, abs_tol=0.00001):
steps += 1
sigs = []
else:
sigs = [0.0]
ss = max(len(s.sigmas) // steps, 1)
while x < len(s.sigmas):
sigs += [float(s.sigmas[x])]
x += ss
sigs = sigs[::-1]
sigs += [0.0]
return torch.FloatTensor(sigs)
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
@ -328,15 +335,34 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
start = s.timestep(s.sigma_max)
end = s.timestep(s.sigma_min)
append_zero = True
if sgm:
timesteps = torch.linspace(start, end, steps + 1)[:-1]
else:
if math.isclose(float(s.sigma(end)), 0, abs_tol=0.00001):
steps += 1
append_zero = False
timesteps = torch.linspace(start, end, steps)
sigs = []
for x in range(len(timesteps)):
ts = timesteps[x]
sigs.append(s.sigma(ts))
sigs.append(float(s.sigma(ts)))
if append_zero:
sigs += [0.0]
return torch.FloatTensor(sigs)
# Implemented based on: https://arxiv.org/abs/2407.12173
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
total_timesteps = (len(model_sampling.sigmas) - 1)
ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)
sigs = []
for t in ts:
sigs += [float(model_sampling.sigmas[int(t)])]
sigs += [0.0]
return torch.FloatTensor(sigs)
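The new "beta" schedule (exposed through the SCHEDULER_NAMES entry above) picks timesteps at quantiles of a Beta(alpha, beta) distribution; a standalone sketch of the index selection with toy numbers, independent of any loaded model:

import numpy
import scipy.stats

steps, total_timesteps = 10, 999   # toy values; total_timesteps is len(sigmas) - 1
ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
ts = numpy.rint(scipy.stats.beta.ppf(ts, 0.6, 0.6) * total_timesteps)
print(ts)  # descending timestep indices, spaced more densely near both ends of the noise range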
@ -709,7 +735,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = None
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "exponential":
@ -722,7 +748,9 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = ddim_scheduler(model_sampling, steps)
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
elif scheduler_name == "beta":
sigmas = beta_scheduler(model_sampling, steps)
if sigmas is None:
logging.error("error invalid scheduler {}".format(scheduler_name))


@ -17,19 +17,22 @@ from . import model_detection
from . import model_management
from . import model_patcher
from . import model_sampling
from .text_encoders import sa_t5
from . import sd1_clip
from . import sd2_clip
from .text_encoders import sd3_clip
from . import sdxl_clip
from . import utils
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
from .text_encoders import hydit
from .text_encoders import sa_t5
from .text_encoders import aura_t5
from .ldm.audio.autoencoder import AudioOobleckVAE
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .t2i_adapter import adapter
from .taesd import taesd
from .text_encoders import aura_t5
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
@ -63,7 +66,9 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
class CLIP:
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None):
def __init__(self, target: CLIPTarget=None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None=None):
if tokenizer_data is None:
tokenizer_data = dict()
if no_init:
return
params = target.params.copy()
@ -84,7 +89,7 @@ class CLIP:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory)
self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.layer_idx = None
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
@ -141,7 +146,11 @@ class CLIP:
return self.cond_stage_model.load_sd(sd)
def get_sd(self):
return self.cond_stage_model.state_dict()
sd_clip = self.cond_stage_model.state_dict()
sd_tokenizer = self.tokenizer.state_dict()
for k in sd_tokenizer:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def load_model(self):
model_management.load_model_gpu(self.patcher)
@ -390,6 +399,7 @@ class CLIPType(Enum):
STABLE_CASCADE = 2
SD3 = 3
STABLE_AUDIO = 4
HUNYUAN_DIT = 5
@dataclasses.dataclass
@ -444,6 +454,9 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
if clip_type == CLIPType.SD3:
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HUNYUAN_DIT:
clip_target.clip = hydit.HyditModel
clip_target.tokenizer = hydit.HyditTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@ -540,7 +553,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
clip = CLIP(clip_target, embedding_directory=embedding_directory)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))


@ -7,6 +7,8 @@ import numbers
import os
import traceback
import zipfile
from pathlib import Path
try:
from importlib.resources.abc import Traversable # pylint: disable=no-name-in-module
except ImportError:
@ -21,7 +23,7 @@ from . import model_management
from . import ops
from .component_model import files
from .component_model.files import get_path_as_dict, get_package_as_path
from .text_encoders.llama_tokenizer import LLAMATokenizer
from .text_encoders.spiece_tokenizer import SPieceTokenizer
def gen_empty_tokens(special_tokens, length):
@ -256,7 +258,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode)
o = self.encode(to_encode)
out, pooled = o[:2]
if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
@ -275,8 +279,20 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
output.append(z)
if (len(output) == 0):
return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
else:
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
if len(o) > 2:
extra = {}
for k in o[2]:
v = o[2][k]
if k == "attention_mask":
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
extra[k] = v
r = r + (extra,)
return r
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False)
@ -448,19 +464,22 @@ SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None):
def __init__(self, tokenizer_path: torch.Tensor | bytes | bytearray | memoryview | str | Path | Traversable = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
if tokenizer_path is None:
tokenizer_path = files.get_package_as_path("comfy.sd1_tokenizer")
if isinstance(tokenizer_path, Path):
tokenizer_path = str(tokenizer_path)
if isinstance(tokenizer_path, Traversable):
contextlib_path = importlib.resources.as_file(tokenizer_path)
tokenizer_path = contextlib_path.__enter__()
tokenizer_path = str(tokenizer_path)
if issubclass(tokenizer_class, CLIPTokenizer) and not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
# assumes sd1_tokenizer
tokenizer_path = get_package_as_path('comfy.sd1_tokenizer')
self.tokenizer_class = tokenizer_class
self.tokenizer_path = tokenizer_path
self.tokenizer: PreTrainedTokenizerBase | LLAMATokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.tokenizer: PreTrainedTokenizerBase | SPieceTokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.min_length = min_length
@ -609,15 +628,17 @@ class SDTokenizer:
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
def state_dict(self):
return {}
SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
class SD1Tokenizer:
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
self.sd_tokenizer = tokenizer(embedding_directory=embedding_directory)
self.sd_tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text: str, return_word_ids=False):
out = {}
@ -640,6 +661,8 @@ class SD1Tokenizer:
sd1_tokenizer.sd_tokenizer = self.sd_tokenizer.clone()
return sd1_tokenizer
def state_dict(self):
return {}
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs):


@ -21,12 +21,12 @@ class SDXLClipG(sd1_clip.SDClipModel):
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
def __init__(self, tokenizer_path=None, embedding_directory=None, **kwargs):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class SDXLTokenizer:
def __init__(self, embedding_directory=None):
def __init__(self, embedding_directory=None, **kwargs):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
@ -39,6 +39,9 @@ class SDXLTokenizer:
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return {}
def clone(self):
sdxl_tokenizer = copy.copy(self)
sdxl_tokenizer.clip_l = self.clip_l.clone()
@ -81,13 +84,13 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
class StableCascadeClipG(sd1_clip.SDClipModel):


@ -3,11 +3,12 @@ from . import model_base
from . import utils
from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
from .text_encoders import sa_t5
from .text_encoders import aura_t5
from .text_encoders import hydit
from . import supported_models_base
from . import latent_formats
@ -580,6 +581,45 @@ class AuraFlow(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(aura_t5.AuraT5Tokenizer, aura_t5.AuraT5Model)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
class HunyuanDiT(supported_models_base.BASE):
unet_config = {
"image_model": "hydit",
}
unet_extra_config = {
"attn_precision": torch.float32,
}
sampling_settings = {
"linear_start": 0.00085,
"linear_end": 0.018,
}
latent_format = latent_formats.SDXL
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanDiT(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(hydit.HyditTokenizer, hydit.HyditModel)
class HunyuanDiT1(HunyuanDiT):
unet_config = {
"image_model": "hydit1",
}
unet_extra_config = {}
sampling_settings = {
"linear_start" : 0.00085,
"linear_end" : 0.03,
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1]
models += [SVD_img2vid]

View File

@ -1,11 +1,10 @@
from importlib import resources
from comfy import sd1_clip
from .llama_tokenizer import LLAMATokenizer
from .spiece_tokenizer import SPieceTokenizer
from ..text_encoders import t5
from ..component_model.files import get_path_as_dict
class PT5XlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
textmodel_json_config = get_path_as_dict(textmodel_json_config, "t5_pile_config_xl.json", package=__package__)
@ -13,14 +12,16 @@ class PT5XlModel(sd1_clip.SDClipModel):
class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
def __init__(self, embedding_directory=None, **kwargs):
tokenizer_path = resources.files("comfy.text_encoders.t5_pile_tokenizer") / "tokenizer.model"
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LLAMATokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
class AuraT5Model(sd1_clip.SD1ClipModel):

139
comfy/text_encoders/bert.py Normal file
View File

@ -0,0 +1,139 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
class BertAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
self.heads = heads
self.query = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.key = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.value = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None, optimized_attention=None):
q = self.query(x)
k = self.key(x)
v = self.value(x)
out = optimized_attention(q, k, v, self.heads, mask)
return out
class BertOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
self.LayerNorm = operations.LayerNorm(output_dim, eps=layer_norm_eps, dtype=dtype, device=device)
# self.dropout = nn.Dropout(0.0)
def forward(self, x, y):
x = self.dense(x)
# hidden_states = self.dropout(hidden_states)
x = self.LayerNorm(x + y)
return x
class BertAttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.self = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = BertOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
y = self.self(x, mask, optimized_attention)
return self.output(y, x)
class BertIntermediate(torch.nn.Module):
def __init__(self, embed_dim, intermediate_dim, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(embed_dim, intermediate_dim, dtype=dtype, device=device)
def forward(self, x):
x = self.dense(x)
return torch.nn.functional.gelu(x)
class BertBlock(torch.nn.Module):
def __init__(self, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttentionBlock(embed_dim, heads, layer_norm_eps, dtype, device, operations)
self.intermediate = BertIntermediate(embed_dim, intermediate_dim, dtype, device, operations)
self.output = BertOutput(intermediate_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
x = self.attention(x, mask, optimized_attention)
y = self.intermediate(x)
return self.output(y, x)
class BertEncoder(torch.nn.Module):
def __init__(self, num_layers, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([BertBlock(embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class BertEmbeddings(torch.nn.Module):
def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.word_embeddings = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device)
self.position_embeddings = torch.nn.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device)
self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device)
self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, input_tokens, token_type_ids=None):
x = self.word_embeddings(input_tokens)
x += self.position_embeddings.weight[:x.shape[1]]
if token_type_ids is not None:
x += self.token_type_embeddings(token_type_ids)
else:
x += self.token_type_embeddings.weight[0]
x = self.LayerNorm(x)
return x
class BertModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
embed_dim = config_dict["hidden_size"]
layer_norm_eps = config_dict["layer_norm_eps"]
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
x, i = self.encoder(x, mask, intermediate_output)
return x, i
class BertModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.bert = BertModel_(config_dict, dtype, device, operations)
self.num_layers = config_dict["num_hidden_layers"]
def get_input_embeddings(self):
return self.bert.embeddings.word_embeddings
def set_input_embeddings(self, embeddings):
self.bert.embeddings.word_embeddings = embeddings
def forward(self, *args, **kwargs):
return self.bert(*args, **kwargs)
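
A minimal standalone sketch of the additive mask construction used in BertModel_.forward above, assuming a (batch, seq_len) attention mask of ones and zeros; padded positions become -inf so softmax ignores them:

import torch

def additive_attention_mask(attention_mask, dtype=torch.float32):
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    b, s = attention_mask.shape
    mask = 1.0 - attention_mask.to(dtype).reshape((b, 1, -1, s)).expand(b, 1, s, s)
    # entries that came from padded positions are now 1.0; turn them into -inf
    return mask.masked_fill(mask.to(torch.bool), float("-inf"))

# mask over a 4-token sequence whose last token is padding:
m = additive_attention_mask(torch.tensor([[1, 1, 1, 0]]))
# m[0, 0] has zeros in its first three columns and -inf in the last one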

View File

@ -0,0 +1,92 @@
from importlib import resources
import torch
from transformers import BertTokenizer
import comfy.text_encoders.t5
from comfy import sd1_clip
from .bert import BertModel
from .spiece_tokenizer import SPieceTokenizer
from ..component_model.files import get_path_as_dict, get_package_as_path
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
textmodel_json_config = get_path_as_dict(textmodel_json_config, "hydit_clip.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, **kwargs):
tokenizer_path = get_package_as_path(f"{__package__}.hydit_clip_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
textmodel_json_config = get_path_as_dict(textmodel_json_config, "mt5_config_xl.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_data=None, **kwargs):
if tokenizer_data is None:
tokenizer_data = dict()
if "spiece_model" not in tokenizer_data:
raise FileNotFoundError("expected a checkpoint that contains the mt5 tokenizer's sentencepiece model")
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class HyditTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None or "mt5xl.spiece_model" not in tokenizer_data:
raise FileNotFoundError("expected mt5xl tokenizer data in the checkpoint")
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
def tokenize_with_weights(self, text: str, return_word_ids=False):
out = {}
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.hydit_clip.untokenize(token_weight_pair)
def state_dict(self):
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
class HyditModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.hydit_clip = HyditBertModel()
self.mt5xl = MT5XLModel()
self.dtypes = set()
if dtype is not None:
self.dtypes.add(dtype)
def encode_token_weights(self, token_weight_pairs):
hydit_out = self.hydit_clip.encode_token_weights(token_weight_pairs["hydit_clip"])
mt5_out = self.mt5xl.encode_token_weights(token_weight_pairs["mt5xl"])
return hydit_out[0], hydit_out[1], {"attention_mask": hydit_out[2]["attention_mask"], "conditioning_mt5xl": mt5_out[0], "attention_mask_mt5xl": mt5_out[2]["attention_mask"]}
def load_sd(self, sd):
if "bert.encoder.layer.0.attention.self.query.weight" in sd:
return self.hydit_clip.load_sd(sd)
else:
return self.mt5xl.load_sd(sd)
def set_clip_options(self, options):
self.hydit_clip.set_clip_options(options)
self.mt5xl.set_clip_options(options)
def reset_clip_options(self):
self.hydit_clip.reset_clip_options()
self.mt5xl.reset_clip_options()

View File

@ -0,0 +1,35 @@
{
"_name_or_path": "hfl/chinese-roberta-wwm-ext-large",
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"bos_token_id": 0,
"classifier_dropout": null,
"directionality": "bidi",
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"output_past": true,
"pad_token_id": 0,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"position_embedding_type": "absolute",
"torch_dtype": "float32",
"transformers_version": "4.22.1",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 47020
}

View File

@ -0,0 +1,7 @@
{
"cls_token": "[CLS]",
"mask_token": "[MASK]",
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"unk_token": "[UNK]"
}

View File

@ -0,0 +1,16 @@
{
"cls_token": "[CLS]",
"do_basic_tokenize": true,
"do_lower_case": true,
"mask_token": "[MASK]",
"name_or_path": "hfl/chinese-roberta-wwm-ext",
"never_split": null,
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
"strip_accents": null,
"tokenize_chinese_chars": true,
"tokenizer_class": "BertTokenizer",
"unk_token": "[UNK]",
"model_max_length": 77
}

File diff suppressed because it is too large

View File

@ -1,24 +0,0 @@
class LLAMATokenizer:
# todo: not sure why we're not using the tokenizer from transformers for this
@staticmethod
def from_pretrained(path):
return LLAMATokenizer(path)
def __init__(self, tokenizer_path):
import sentencepiece
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path) # pylint: disable=unexpected-keyword-arg
self.end = self.tokenizer.eos_id()
self.eos_token_id = self.end
self.eos_token = self.tokenizer.id_to_piece(self.eos_token_id) # pylint: disable=no-member
self._vocab = {
self.tokenizer.id_to_piece(i): i for i in range(self.tokenizer.get_piece_size()) # pylint: disable=no-member
}
def get_vocab(self):
return self._vocab
def __call__(self, string):
out = self.tokenizer.encode(string) # pylint: disable=no-member
out += [self.end]
return {"input_ids": out}

View File

@ -0,0 +1,22 @@
{
"d_ff": 5120,
"d_kv": 64,
"d_model": 2048,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "mt5",
"num_decoder_layers": 24,
"num_heads": 32,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 250112
}

View File

@ -12,14 +12,16 @@ class T5BaseModel(sd1_clip.SDClipModel):
class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
def __init__(self, *args, **kwargs):
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5base", tokenizer=T5BaseTokenizer)
class SAT5Model(sd1_clip.SD1ClipModel):

View File

@ -1,6 +1,6 @@
from . import sd1_clip
from .. import sd1_clip
from .component_model.files import get_path_as_dict
from ..component_model.files import get_path_as_dict
class SD2ClipHModel(sd1_clip.SDClipModel):
@ -14,13 +14,15 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
def __init__(self, tokenizer_path=None, embedding_directory=None, **kwargs):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
class SD2ClipModel(sd1_clip.SD1ClipModel):

View File

@ -17,23 +17,17 @@ class T5XXLModel(sd1_clip.SDClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
class SDT5XXLModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
class SD3Tokenizer:
def __init__(self, embedding_directory=None):
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@ -48,6 +42,8 @@ class SD3Tokenizer:
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
def state_dict(self):
return dict()
class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):

View File

@ -0,0 +1,38 @@
import sentencepiece
import torch
class SPieceTokenizer:
add_eos = True
@staticmethod
def from_pretrained(path):
return SPieceTokenizer(path)
def __init__(self, tokenizer_path):
if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes()
construction_args = {}
if isinstance(tokenizer_path, bytes):
construction_args["model_proto"] = tokenizer_path
else:
construction_args["model_file"] = tokenizer_path
self.tokenizer = sentencepiece.SentencePieceProcessor(add_eos=SPieceTokenizer.add_eos, **construction_args) # pylint: disable=unexpected-keyword-arg
self.end = self.tokenizer.eos_id()
self.eos_token_id = self.end
self.eos_token = self.tokenizer.id_to_piece(self.eos_token_id) # pylint: disable=no-member
self._vocab = {
self.tokenizer.id_to_piece(i): i for i in range(self.tokenizer.get_piece_size()) # pylint: disable=no-member
}
def get_vocab(self):
return self._vocab
def __call__(self, string):
out = self.tokenizer.encode(string)
return {"input_ids": out}
def serialize_model(self):
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
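
A rough round-trip sketch for the class above ("spiece.model" is a placeholder path): serialize_model() packs the sentencepiece proto into a ByteTensor that can live in a checkpoint state dict, and the constructor accepts that tensor back directly.

tok = SPieceTokenizer.from_pretrained("spiece.model")   # placeholder path
state = {"spiece_model": tok.serialize_model()}         # torch.ByteTensor of the proto
restored = SPieceTokenizer(state["spiece_model"])       # tensor -> bytes -> model_proto
assert restored("hello world")["input_ids"] == tok("hello world")["input_ids"]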

View File

@ -223,7 +223,7 @@ class T5(torch.nn.Module):
self.num_layers = config_dict["num_layers"]
model_dim = config_dict["d_model"]
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] == "t5", dtype, device, operations)
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
self.dtype = dtype
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)

View File

@ -149,6 +149,15 @@ app.registerExtension({
}
audioWidget.callback = onAudioWidgetUpdate
// Load saved audio file widget values if restoring from workflow
const onGraphConfigured = node.onGraphConfigured;
node.onGraphConfigured = function() {
onGraphConfigured?.apply(this, arguments)
if (audioWidget.value) {
onAudioWidgetUpdate()
}
}
const fileInput = document.createElement("input")
fileInput.type = "file"
fileInput.accept = "audio/*"

View File

@ -136,6 +136,9 @@ class ComfyApi extends EventTarget {
case "execution_start":
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
break;
case "execution_success":
this.dispatchEvent(new CustomEvent("execution_success", { detail: msg.data }));
break;
case "execution_error":
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
break;

View File

@ -105,15 +105,16 @@ export class ChangeTracker {
window.addEventListener(
"keydown",
(e) => {
const activeEl = document.activeElement;
requestAnimationFrame(async () => {
let activeEl;
let bindInputEl;
// If auto queue is in 'change' mode then we do want inputs to trigger a state check
if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") {
activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.["type"] === "textarea") {
// Ignore events on inputs, they have their native history
return;
}
bindInputEl = activeEl;
}
keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta";
@ -123,7 +124,7 @@ export class ChangeTracker {
if (await changeTracker().undoRedo(e)) return;
// If our active element is some type of input then handle changes after they're done
if (ChangeTracker.bindInput(activeEl)) return;
if (ChangeTracker.bindInput(bindInputEl)) return;
changeTracker().checkState();
});
},

View File

@ -49,7 +49,7 @@ export function getPngMetadata(file) {
function parseExifData(exifData) {
// Check the TIFF header byte-order marker: "II" (0x4949) for little-endian or "MM" (0x4D4D) for big-endian
const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
const isLittleEndian = String.fromCharCode(...exifData.slice(0, 2)) === "II";
// Function to read 16-bit and 32-bit integers from binary data
function readInt(offset, isLittleEndian, length) {
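
For reference, a hedged Python analogue of the byte-order check above: the TIFF/EXIF payload starts with the ASCII marker "II" (little-endian, 0x4949) or "MM" (big-endian, 0x4D4D), and comparing the two marker bytes directly sidesteps the typed-array conversion the removed line relied on.

def exif_is_little_endian(exif: bytes) -> bool:
    marker = exif[:2]
    if marker == b"II":   # 0x49 0x49, little-endian TIFF header
        return True
    if marker == b"MM":   # 0x4D 0x4D, big-endian TIFF header
        return False
    raise ValueError("not a TIFF/EXIF header")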
@ -134,6 +134,7 @@ export function getWebpMetadata(file) {
let index = value.indexOf(':');
txt_chunks[value.slice(0, index)] = value.slice(index + 1);
}
break;
}
offset += 8 + chunk_length;

View File

@ -330,6 +330,7 @@
.comfyui-workflows-open .active {
font-weight: bold;
color: var(--primary-fg);
}
.comfyui-workflows-favorites:empty {
@ -417,6 +418,10 @@
padding: 2px 4px;
}
.comfyui-workflows-tree-file.active .comfyui-workflows-file-action {
color: var(--primary-fg);
}
.lg ~ .comfyui-workflows-popup .comfyui-workflows-tree-file:not(:hover) .comfyui-workflows-file-action {
opacity: 0;
}

View File

@ -151,7 +151,7 @@ class SaveAudio:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
for (batch_number, waveform) in enumerate(audio["waveform"]):
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.flac"

View File

@ -1,20 +1,10 @@
UNION_CONTROLNET_TYPES = {"auto": -1,
"openpose": 0,
"depth": 1,
"hed/pidi/scribble/ted": 2,
"canny/lineart/anime_lineart/mlsd": 3,
"normal": 4,
"segment": 5,
"tile": 6,
"repaint": 7,
}
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
class SetUnionControlNetType:
@classmethod
def INPUT_TYPES(s):
return {"required": {"control_net": ("CONTROL_NET", ),
"type": (list(UNION_CONTROLNET_TYPES.keys()),)
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
}}
CATEGORY = "conditioning/controlnet"
@ -24,7 +14,7 @@ class SetUnionControlNetType:
def set_controlnet_type(self, control_net, type):
control_net = control_net.copy()
type_number = UNION_CONTROLNET_TYPES[type]
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
if type_number >= 0:
control_net.set_extra_arg("control_type", [type_number])
else:
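
The lookup behaviour above in isolation (UNION_CONTROLNET_TYPES is the table imported from comfy.cldm.control_types): named types resolve to their index, while "auto", or any unknown key, falls back to -1 and leaves control_type unset.

from comfy.cldm.control_types import UNION_CONTROLNET_TYPES

assert UNION_CONTROLNET_TYPES.get("depth", -1) == 1      # explicit type -> index
assert UNION_CONTROLNET_TYPES.get("auto", -1) == -1      # "auto" is not in the table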

View File

@ -114,6 +114,25 @@ class SDTurboScheduler:
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
class BetaSamplingScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
"beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps, alpha, beta):
sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
return (sigmas, )
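
Illustrative only: one plausible way a Beta(alpha, beta) schedule can pick sigmas from a model's sigma table, shown here as a sketch rather than the actual comfy.samplers.beta_scheduler implementation.

import numpy as np
import scipy.stats
import torch

def beta_sigmas_sketch(sigmas: torch.Tensor, steps: int, alpha=0.6, beta=0.6):
    # map evenly spaced quantiles through the Beta inverse CDF onto timestep indices
    total_timesteps = len(sigmas) - 1
    ts = 1.0 - np.linspace(0, 1, steps, endpoint=False)
    idx = np.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps).astype(int)
    return torch.FloatTensor([float(sigmas[i]) for i in idx] + [0.0])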
class VPScheduler:
@classmethod
def INPUT_TYPES(s):
@ -279,6 +298,23 @@ class SamplerDPMPP_SDE:
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, )
class SamplerDPMPP_2S_Ancestral:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, eta, s_noise):
sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise})
return (sampler, )
class SamplerEulerAncestral:
@classmethod
def INPUT_TYPES(s):
@ -641,6 +677,7 @@ NODE_CLASS_MAPPINGS = {
"ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler,
"VPScheduler": VPScheduler,
"BetaSamplingScheduler": BetaSamplingScheduler,
"SDTurboScheduler": SDTurboScheduler,
"KSamplerSelect": KSamplerSelect,
"SamplerEulerAncestral": SamplerEulerAncestral,
@ -649,6 +686,7 @@ NODE_CLASS_MAPPINGS = {
"SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
"SamplerDPMAdaptative": SamplerDPMAdaptative,
"SplitSigmas": SplitSigmas,
"SplitSigmasDenoise": SplitSigmasDenoise,
@ -665,4 +703,4 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++",
}
}

View File

@ -58,7 +58,7 @@ class FreeU:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "model_patches"
CATEGORY = "model_patches/unet"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
@ -97,7 +97,7 @@ class FreeU_V2:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "model_patches"
CATEGORY = "model_patches/unet"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]

View File

@ -0,0 +1,24 @@
class CLIPTextEncodeHunyuanDiT:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, bert, mt5xl):
tokens = clip.tokenize(bert)
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
return ([[cond, output]], )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
}
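
A hypothetical API-format fragment using the node above (the node id and the upstream checkpoint/CLIP reference are placeholders): both prompts are tokenized, and the mt5xl tokens are merged into the bert conditioning by the encode method.

prompt_fragment = {
    "10": {
        "class_type": "CLIPTextEncodeHunyuanDiT",
        "inputs": {
            "clip": ["4", 1],   # placeholder: CLIP output of a checkpoint loader node
            "bert": "a red fox in a bamboo forest",
            "mt5xl": "a red fox in a bamboo forest",
        },
    },
}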

View File

@ -32,7 +32,7 @@ class HyperTile:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "model_patches"
CATEGORY = "model_patches/unet"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
model_channels = model.model.model_config.unet_config["model_channels"]

View File

@ -19,7 +19,7 @@ class PerturbedAttentionGuidance:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
CATEGORY = "model_patches/unet"
def patch(self, model, scale):
unet_block = "middle"

View File

@ -6,6 +6,113 @@ from comfy.model_downloader import add_known_models, KNOWN_LORAS
from comfy.model_downloader_types import CivitFile
_workflows = {
"hunyuandit_1": {
"3": {
"inputs": {
"seed": 377072733774956,
"steps": 1,
"cfg": 6,
"sampler_name": "euler",
"scheduler": "sgm_uniform",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "hunyuan_dit_1.2.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "a bottle with a 彩虹星系 inside it on top of a wooden table on a snowy mountain top with the ocean and clouds in the background\n",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"7": {
"inputs": {
"text": "",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
},
"audio_1": {
"14": {
"inputs": {