Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2024-07-15 14:22:49 -07:00
commit 3d1d833e6f
42 changed files with 1300 additions and 290 deletions

View File

@ -1,45 +1,48 @@
name: Bug Report
description: "Something is broken inside of ComfyUI. (Do not use this if you're just having issues and need help, or if the issue relates to a custom node)"
labels: [ "Potential Bug" ]
labels: ["Potential Bug"]
body:
- type: markdown
attributes:
value: |
Before submitting a **Bug Report**, please ensure the following:
- type: markdown
attributes:
value: |
Before submitting a **Bug Report**, please ensure the following:
**1:** You are running the latest version of ComfyUI.
**2:** You have looked at the existing bug reports and made sure this isn't already reported.
**3:** This is an actual bug in ComfyUI, not just a support question and not caused by an custom node. A bug is when you can specify exact steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
- **1:** You are running the latest version of ComfyUI.
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
`--disable-all-custom-nodes` command line argument.
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: textarea
attributes:
label: Expected Behavior
description: "What you expected to happen."
validations:
required: true
- type: textarea
attributes:
label: Actual Behavior
description: "What actually happened. Please include a screenshot of the issue if possible."
validations:
required: true
- type: textarea
attributes:
label: Steps to Reproduce
description: "Describe how to reproduce the issue. Please be sure to attach a workflow JSON or PNG, ideally one that doesn't require custom nodes to test. If the bug open happens when certain custom nodes are used, most likely that custom node is what has the bug rather than ComfyUI, in which case it should be reported to the node's author."
validations:
required: true
- type: textarea
attributes:
label: Debug Logs
description: "Please copy the output from your terminal logs here."
render: powershell
validations:
required: true
- type: textarea
attributes:
label: Other
description: "Any other additional information you think might be helpful."
validations:
required: false
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: textarea
attributes:
label: Expected Behavior
description: "What you expected to happen."
validations:
required: true
- type: textarea
attributes:
label: Actual Behavior
description: "What actually happened. Please include a screenshot of the issue if possible."
validations:
required: true
- type: textarea
attributes:
label: Steps to Reproduce
description: "Describe how to reproduce the issue. Please be sure to attach a workflow JSON or PNG, ideally one that doesn't require custom nodes to test. If the bug open happens when certain custom nodes are used, most likely that custom node is what has the bug rather than ComfyUI, in which case it should be reported to the node's author."
validations:
required: true
- type: textarea
attributes:
label: Debug Logs
description: "Please copy the output from your terminal logs here."
render: powershell
validations:
required: true
- type: textarea
attributes:
label: Other
description: "Any other additional information you think might be helpful."
validations:
required: false

109
.github/workflows/stable-release.yml vendored Normal file
View File

@ -0,0 +1,109 @@
name: "Release Stable Version"
on:
push:
tags:
- 'v*'
jobs:
package_comfy_windows:
permissions:
contents: "write"
packages: "write"
pull-requests: "read"
runs-on: windows-latest
strategy:
matrix:
python_version: [3.11.8]
cuda_version: [121]
steps:
- name: Calculate Minor Version
shell: bash
run: |
# Extract the minor version from the Python version
MINOR_VERSION=$(echo "${{ matrix.python_version }}" | cut -d'.' -f2)
echo "MINOR_VERSION=$MINOR_VERSION" >> $GITHUB_ENV
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
- uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- shell: bash
run: |
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir cu${{ matrix.cuda_version }}_python_deps
mv cu${{ matrix.cuda_version }}_python_deps ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
pwd
ls
cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/${{ matrix.python_version }}/python-${{ matrix.python_version }}-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded
cd python_embeded
echo ${{ env.MINOR_VERSION }}
echo 'import site' >> ./python3${{ env.MINOR_VERSION }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe --version
echo "Pip version:"
./python.exe -m pip --version
set PATH=$PWD/Scripts:$PATH
echo $PATH
./python.exe -s -m pip install ../cu${{ matrix.cuda_version }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ env.MINOR_VERSION }}._pth
cd ..
git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable
mv python_embeded ComfyUI_windows_portable
mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
cd ComfyUI_windows_portable
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
cd ComfyUI_windows_portable
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
ls
- name: Upload binaries to release
uses: svenstaro/upload-release-action@v2
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ github.ref }}
overwrite: true

View File

@ -41,7 +41,7 @@ jobs:
working-directory: ComfyUI
- name: Start ComfyUI server
run: |
python main.py --cpu &
python main.py --cpu 2>&1 | tee console_output.log &
wait-for-it --service 127.0.0.1:8188 -t 600
working-directory: ComfyUI
- name: Install ComfyUI_frontend dependencies
@ -54,9 +54,22 @@ jobs:
- name: Run Playwright tests
run: npx playwright test
working-directory: ComfyUI_frontend
- name: Check for unhandled exceptions in server log
run: |
if grep -qE "Exception|Error" console_output.log; then
echo "Unhandled exception/error found in server log."
exit 1
fi
working-directory: ComfyUI
- uses: actions/upload-artifact@v4
if: always()
with:
name: playwright-report
path: ComfyUI_frontend/playwright-report/
retention-days: 30
- uses: actions/upload-artifact@v4
if: always()
with:
name: console-output
path: ComfyUI/console_output.log
retention-days: 30

View File

@ -42,6 +42,7 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.

View File

@ -10,10 +10,51 @@ from ..ldm.modules.diffusionmodules.util import (
timestep_embedding,
)
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.attention import SpatialTransformer, optimized_attention
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
from .. import ops
from collections import OrderedDict
class OptimizedAttention(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.heads = nhead
self.c = c
self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, x):
x = self.in_proj(x)
q, k, v = x.split(self.c, dim=2)
out = optimized_attention(q, k, v, self.heads)
return self.out_proj(out)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResBlockUnionControlnet(nn.Module):
def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
super().__init__()
self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
self.mlp = nn.Sequential(
OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
def attention(self, x: torch.Tensor):
return self.attn(x)
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class ControlledUnetModel(UNetModel):
#implemented in the ldm unet
@ -53,6 +94,7 @@ class ControlNet(nn.Module):
transformer_depth_middle=None,
transformer_depth_output=None,
attn_precision=None,
union_controlnet_num_control_type=None,
device=None,
operations=ops.disable_weight_init,
**kwargs,
@ -280,6 +322,65 @@ class ControlNet(nn.Module):
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch
if union_controlnet_num_control_type is not None:
self.num_control_type = union_controlnet_num_control_type
num_trans_channel = 320
num_trans_head = 8
num_trans_layer = 1
num_proj_channel = 320
# task_scale_factor = num_trans_channel ** 0.5
self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
#-----------------------------------------------------------------------------------------------------
control_add_embed_dim = 256
class ControlAddEmbedding(nn.Module):
def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
super().__init__()
self.num_control_type = num_control_type
self.in_dim = in_dim
self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
def forward(self, control_type, dtype, device):
c_type = torch.zeros((self.num_control_type,), device=device)
c_type[control_type] = 1.0
c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
else:
self.task_embedding = None
self.control_add_embedding = None
def union_controlnet_merge(self, hint, control_type, emb, context):
# Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
inputs = []
condition_list = []
for idx in range(min(1, len(control_type))):
controlnet_cond = self.input_hint_block(hint[idx], emb, context)
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
if idx < len(control_type):
feat_seq += self.task_embedding[control_type[idx]]
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(controlnet_cond)
x = torch.cat(inputs, dim=1)
x = self.transformer_layes(x)
controlnet_cond_fuser = None
for idx in range(len(control_type)):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
o = condition_list[idx] + alpha
if controlnet_cond_fuser is None:
controlnet_cond_fuser = o
else:
controlnet_cond_fuser += o
return controlnet_cond_fuser
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
@ -287,7 +388,18 @@ class ControlNet(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context)
guided_hint = None
if self.control_add_embedding is not None: #Union Controlnet
control_type = kwargs.get("control_type", [])
emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
if len(control_type) > 0:
if len(hint.shape) < 5:
hint = hint.unsqueeze(dim=0)
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
if guided_hint is None:
guided_hint = self.input_hint_block(hint, emb, context)
out_output = []
out_middle = []

View File

@ -1,3 +1,4 @@
from .component_model import files
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import os
import torch
@ -30,9 +31,17 @@ def clip_preprocess(image, size=224):
return (image - mean.view([3,1,1])) / std.view([3,1,1])
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
def __init__(self, json_config: dict | str):
if isinstance(json_config, dict):
config = json_config
elif json_config is not None and isinstance(json_config, str):
if json_config.startswith("{"):
config = json.loads(json_config)
else:
with open(json_config) as f:
config = json.load(f)
else:
raise ValueError(f"json_config had invalid value={json_config}")
self.load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
@ -88,12 +97,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
if convert_keys:
sd = convert_to_transformers(sd, prefix)
if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
# todo: fix the importlib issue here
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
json_config = files.get_path_as_dict(None, "clip_vision_config_g.json")
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
json_config = files.get_path_as_dict(None, "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json")
else:
return None

View File

@ -85,7 +85,8 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def cleanup_temp():
try:
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
folder_paths.get_temp_directory()
temp_dir = folder_paths.get_temp_directory()
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
except NameError:
@ -115,7 +116,7 @@ async def main():
# configure extra model paths earlier
try:
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
extra_model_paths_config_path = os.path.join(os.getcwd(), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
except NameError:

View File

@ -439,6 +439,7 @@ class PromptServer(ExecutorToClientProgress):
info['name'] = node_class
info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else ''
info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
@ -845,18 +846,9 @@ class PromptServer(ExecutorToClientProgress):
return json_data
@classmethod
def get_output_path(cls, subfolder: str | None = None, filename: str | None = None):
paths = [path for path in ["output", subfolder, filename] if path is not None and path != ""]
return os.path.join(os.path.dirname(os.path.realpath(__file__)), *paths)
@classmethod
def get_upload_dir(cls) -> str:
upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../input")
if not os.path.exists(upload_dir):
os.makedirs(upload_dir)
return upload_dir
return folder_paths.get_input_directory()
@classmethod
def get_too_busy_queue_size(cls):

View File

@ -19,7 +19,7 @@ def get_path_as_dict(config_dict_or_path: str | dict | None, config_path_inside_
config: dict | None = None
if config_dict_or_path is None:
config_dict_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config_path_inside_package)
config_dict_or_path = config_path_inside_package
if isinstance(config_dict_or_path, str):
if config_dict_or_path.startswith("{"):

View File

@ -412,6 +412,12 @@ def load_controlnet(ckpt_path, model=None):
if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
for k in list(controlnet_data.keys()):
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
new_sd[new_k] = controlnet_data.pop(k)
leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys))

479
comfy/ldm/aura/mmdit.py Normal file
View File

@ -0,0 +1,479 @@
#AuraFlow MMDiT
#Originally written by the AuraFlow Authors
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
n_hidden = int(2 * hidden_dim / 3)
n_hidden = find_multiple(n_hidden, 256)
self.c_fc1 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
self.c_fc2 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
self.c_proj = operations.Linear(n_hidden, dim, bias=False, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
x = self.c_proj(x)
return x
class MultiHeadLayerNorm(nn.Module):
def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
# Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78
super().__init__()
self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states - mean) * torch.rsqrt(
variance + self.variance_epsilon
)
hidden_states = self.weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype)
class SingleAttention(nn.Module):
def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
super().__init__()
self.n_heads = n_heads
self.head_dim = dim // n_heads
# this is for cond
self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.q_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.k_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
#@torch.compile()
def forward(self, c):
bsz, seqlen1, _ = c.shape
q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
c = self.w1o(output)
return c
class DoubleAttention(nn.Module):
def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
super().__init__()
self.n_heads = n_heads
self.head_dim = dim // n_heads
# this is for cond
self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
# this is for x
self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
self.q_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.k_norm1 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.q_norm2 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
self.k_norm2 = (
MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
if mh_qknorm
else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
)
#@torch.compile()
def forward(self, c, x):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
seqlen = seqlen1 + seqlen2
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
cq, ck = self.q_norm1(cq), self.k_norm1(ck)
xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
xq, xk = self.q_norm2(xq), self.k_norm2(xk)
# concat all
q, k, v = (
torch.cat([cq, xq], dim=1),
torch.cat([ck, xk], dim=1),
torch.cat([cv, xv], dim=1),
)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
x = self.w2o(x)
return c, x
class MMDiTBlock(nn.Module):
def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
super().__init__()
self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
if not is_last:
self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
self.modC = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
)
else:
self.modC = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
)
self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
self.modX = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
)
self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
self.is_last = is_last
#@torch.compile()
def forward(self, c, x, global_cond, **kwargs):
cres, xres = c, x
cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
self.modC(global_cond).chunk(6, dim=1)
)
c = modulate(self.normC1(c), cshift_msa, cscale_msa)
# xpath
xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
self.modX(global_cond).chunk(6, dim=1)
)
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
# attention
c, x = self.attn(c, x)
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
c = cres + c
x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
x = xres + x
return c, x
class DiTBlock(nn.Module):
# like MMDiTBlock, but it only has X
def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
self.modCX = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
)
self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
#@torch.compile()
def forward(self, cx, global_cond, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
cx = self.attn(cx)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
cx = cxres + cx
return cx
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = 1000 * torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half) / half
).to(t.device)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
#@torch.compile()
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class MMDiT(nn.Module):
def __init__(
self,
in_channels=4,
out_channels=4,
patch_size=2,
dim=3072,
n_layers=36,
n_double_layers=4,
n_heads=12,
global_conddim=3072,
cond_seq_dim=2048,
max_seq=32 * 32,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)
self.cond_seq_linear = operations.Linear(
cond_seq_dim, dim, bias=False, dtype=dtype, device=device
) # linear for something like text sequence.
self.init_x_linear = operations.Linear(
patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
) # init linear for patchified image.
self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))
self.double_layers = nn.ModuleList([])
self.single_layers = nn.ModuleList([])
for idx in range(n_double_layers):
self.double_layers.append(
MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
)
for idx in range(n_double_layers, n_layers):
self.single_layers.append(
DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
)
self.final_linear = operations.Linear(
dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
)
self.modF = nn.Sequential(
nn.SiLU(),
operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
)
self.out_channels = out_channels
self.patch_size = patch_size
self.n_double_layers = n_double_layers
self.n_layers = n_layers
self.h_max = round(max_seq**0.5)
self.w_max = round(max_seq**0.5)
@torch.no_grad()
def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
# extend pe
pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]
pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)
# now we need to extend this to target_dim. for this we will use interpolation.
# we will use torch.nn.functional.interpolate
pe_as_2d = F.interpolate(
pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
)
pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
self.h_max, self.w_max = target_dim
print("PE extended to", target_dim)
def pe_selection_index_based_on_dim(self, h, w):
h_p, w_p = h // self.patch_size, w // self.patch_size
original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
starth = self.h_max // 2 - h_p // 2
endh =starth + h_p
startw = self.w_max // 2 - w_p // 2
endw = startw + w_p
original_pe_indexes = original_pe_indexes[
starth:endh, startw:endw
]
return original_pe_indexes.flatten()
def unpatchify(self, x, h, w):
c = self.out_channels
p = self.patch_size
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
def patchify(self, x):
B, C, H, W = x.size()
pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
x = x.view(
B,
C,
(H + 1) // self.patch_size,
self.patch_size,
(W + 1) // self.patch_size,
self.patch_size,
)
x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
return x
def apply_pos_embeds(self, x, h, w):
h = (h + 1) // self.patch_size
w = (w + 1) // self.patch_size
max_dim = max(h, w)
cur_dim = self.h_max
pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
if max_dim > cur_dim:
pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
cur_dim = max_dim
from_h = (cur_dim - h) // 2
from_w = (cur_dim - w) // 2
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, **kwargs):
# patchify x, add PE
b, c, h, w = x.shape
# pe_indexes = self.pe_selection_index_based_on_dim(h, w)
# print(pe_indexes, pe_indexes.shape)
x = self.init_x_linear(self.patchify(x)) # B, T_x, D
x = self.apply_pos_embeds(x, h, w)
# x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
# x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)
# process conditions for MMDiT Blocks
c_seq = context # B, T_c, D_c
t = timestep
c = self.cond_seq_linear(c_seq) # B, T_c, D
c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
global_cond = self.t_embedder(t, x.dtype) # B, D
if len(self.double_layers) > 0:
for layer in self.double_layers:
c, x = layer(c, x, global_cond, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
cx = torch.cat([c, x], dim=1)
for layer in self.single_layers:
cx = layer(cx, global_cond, **kwargs)
x = cx[:, c_len:]
fshift, fscale = self.modF(global_cond).chunk(2, dim=1)
x = modulate(x, fshift, fscale)
x = self.final_linear(x)
x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
return x

View File

@ -272,4 +272,12 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
key_map[key_lora] = to
if isinstance(model, model_base.AuraFlow): #Diffusers lora AuraFlow
diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
key_map[key_lora] = to
return key_map

View File

@ -17,7 +17,7 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.aura.mmdit import MMDiT as AuraMMDiT
class ModelType(Enum):
EPS = 1
@ -622,6 +622,17 @@ class SD3(BaseModel):
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * 0.3) * (1024 * 1024)
class AuraFlow(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=AuraMMDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDRegular(cross_attn)
return out
class StableAudio1(BaseModel):
def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):

View File

@ -104,6 +104,19 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["audio_model"] = "dit1.0"
return unet_config
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
unet_config = {}
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
unet_config["n_double_layers"] = double_layers
unet_config["n_layers"] = double_layers + single_layers
return unet_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
unet_config = {
"use_checkpoint": False,
"image_size": 32,
@ -238,6 +251,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix)
if unet_config is None:
return None
model_config = model_config_from_unet_config(unet_config, state_dict)
if model_config is None and use_base_if_no_match:
return supported_models_base.BASE(unet_config)
@ -247,6 +262,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
def unet_prefix_from_state_dict(state_dict):
if "model.model.postprocess_conv.weight" in state_dict: #audio models
unet_key_prefix = "model.model."
elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
unet_key_prefix = "model."
else:
unet_key_prefix = "model.diffusion_model."
return unet_key_prefix
@ -436,38 +453,45 @@ def model_config_from_diffusers_unet(state_dict):
return None
def convert_diffusers_mmdit(state_dict, output_prefix=""):
out_sd = None
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
if num_blocks > 0:
out_sd = {}
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
out_sd = {}
sd_map = utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
for k in sd_map:
weight = state_dict.get(k, None)
if weight is not None:
t = sd_map[k]
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
else:
return None
if not isinstance(t, str):
if len(t) > 2:
fun = t[2]
else:
fun = lambda a: a
offset = t[1]
if offset is not None:
old_weight = out_sd.get(t[0], None)
if old_weight is None:
old_weight = torch.empty_like(weight)
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
for k in sd_map:
weight = state_dict.get(k, None)
if weight is not None:
t = sd_map[k]
w = old_weight.narrow(offset[0], offset[1], offset[2])
else:
old_weight = weight
w = weight
w[:] = fun(weight)
t = t[0]
out_sd[t] = old_weight
if not isinstance(t, str):
if len(t) > 2:
fun = t[2]
else:
out_sd[t] = weight
state_dict.pop(k)
fun = lambda a: a
offset = t[1]
if offset is not None:
old_weight = out_sd.get(t[0], None)
if old_weight is None:
old_weight = torch.empty_like(weight)
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
w = old_weight.narrow(offset[0], offset[1], offset[2])
else:
old_weight = weight
w = weight
w[:] = fun(weight)
t = t[0]
out_sd[t] = old_weight
else:
out_sd[t] = weight
state_dict.pop(k)
return out_sd

View File

@ -60,6 +60,12 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl
model_options["disable_cfg1_optimization"] = True
return model_options
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
if disable_cfg1_optimization:
model_options["disable_cfg1_optimization"] = True
return model_options
class ModelPatcher(ModelManageable):
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
@ -142,6 +148,9 @@ class ModelPatcher(ModelManageable):
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
self.model_options["model_function_wrapper"] = unet_wrapper_function

View File

@ -192,11 +192,12 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.0))
self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
def set_parameters(self, shift=1.0, timesteps=1000):
def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
self.shift = shift
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
self.multiplier = multiplier
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
self.register_buffer('sigmas', ts)
@property
@ -208,10 +209,10 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
return sigma * self.multiplier
def sigma(self, timestep):
return time_snr_shift(self.shift, timestep / 1000)
return time_snr_shift(self.shift, timestep / self.multiplier)
def percent_to_sigma(self, percent):
if percent <= 0.0:

View File

@ -46,8 +46,9 @@ class CLIPTextEncode:
def encode(self, clip, text):
tokens = clip.tokenize(text)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
return ([[cond, {"pooled_output": pooled}]], )
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
return ([[cond, output]], )
class ConditioningCombine:
@classmethod
@ -223,8 +224,9 @@ class ConditioningZeroOut:
c = []
for t in conditioning:
d = t[1].copy()
if "pooled_output" in d:
d["pooled_output"] = torch.zeros_like(d["pooled_output"])
pooled_output = d.get("pooled_output", None)
if pooled_output is not None:
d["pooled_output"] = torch.zeros_like(pooled_output)
n = [torch.zeros_like(t[0]), d]
c.append(n)
return (c, )

View File

@ -1,22 +1,27 @@
from comfy import sd1_clip
from transformers import T5TokenizerFast
import comfy.t5
import os
from comfy import sd1_clip
from comfy.component_model import files
class T5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
class SAT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="t5base", clip_model=T5BaseModel, **kwargs)
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)

View File

@ -278,6 +278,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
out = fn(args)
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)

View File

@ -28,37 +28,7 @@ from .t2i_adapter import adapter
from .taesd import taesd
from . import sd3_clip
from . import sa_t5
def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
m = set(m)
unexpected_keys = set(u)
k = list(sd.keys())
for x in k:
if x not in unexpected_keys:
w = sd.pop(x)
del w
if len(m) > 0:
logging.warning("missing {}".format(m))
return model
def load_clip_weights(model, sd):
k = list(sd.keys())
for x in k:
if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
sd[y] = sd.pop(x)
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd:
ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids']
if ids.dtype == torch.float32:
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
sd = utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.")
return load_model_weights(model, sd)
from .text_encoders import aura_t5
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
@ -136,7 +106,7 @@ class CLIP:
def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def encode_from_tokens(self, tokens, return_pooled=False):
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None:
@ -146,7 +116,15 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
out = {"cond": cond, "pooled_output": pooled}
if len(o) > 2:
for k in o[2]:
out[k] = o[2][k]
return out
if return_pooled:
return cond, pooled
return cond
@ -447,9 +425,14 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
dtype_t5 = weight.dtype
if weight.shape[-1] == 4096:
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
elif weight.shape[-1] == 2048:
clip_target.clip = aura_t5.AuraT5Model
clip_target.tokenizer = aura_t5.AuraT5Tokenizer
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
clip_target.clip = sa_t5.SAT5Model
clip_target.tokenizer = sa_t5.SAT5Tokenizer
@ -529,13 +512,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
if model_config.clip_vision_prefix is not None:
if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
@ -586,42 +569,37 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet_state_dict(sd): # load unet in diffusers or regular format
#Allow loading unets from checkpoint files
checkpoint = False
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
temp_sd = utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
if len(temp_sd) > 0:
sd = temp_sd
checkpoint = True
parameters = utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: # ldm or stable cascade
model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None:
return None
if model_config is not None:
new_sd = sd
elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
else:
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
if new_sd is None:
return None
model_config = model_detection.model_config_from_unet(new_sd, "")
if model_config is None:
return None
else: # diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None:
return None
if new_sd is not None: #diffusers mmdit
model_config = model_detection.model_config_from_unet(new_sd, "")
if model_config is None:
return None
else: # diffusers unet
model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None:
return None
diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
logging.warning("{} {}".format(diffusers_keys[k], k))
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)

View File

@ -1,11 +1,13 @@
from __future__ import annotations
import copy
import importlib.resources
import logging
import numbers
import os
import traceback
import zipfile
from importlib.abc import Traversable
from typing import Tuple, Sequence, TypeVar
import torch
@ -14,6 +16,7 @@ from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMi
from . import clip_model
from . import model_management
from . import ops
from .component_model import files
from .component_model.files import get_path_as_dict, get_package_as_path
@ -29,7 +32,58 @@ def gen_empty_tokens(special_tokens, length):
output += [pad_token] * (length - len(output))
return output
class SDClipModel(torch.nn.Module):
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
o = self.encode(to_encode)
out, pooled = o[:2]
if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
output = []
for k in range(0, sections):
z = out[k:k+1]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z)
if (len(output) == 0):
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
else:
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
if len(o) > 2:
extra = {}
for k in o[2]:
v = o[2][k]
if k == "attention_mask":
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
extra[k] = v
r = r + (extra,)
return r
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
@ -40,7 +94,7 @@ class SDClipModel(torch.nn.Module):
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True): # clip-vit-base-patch32
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
super().__init__()
if special_tokens is None:
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
@ -63,6 +117,7 @@ class SDClipModel(torch.nn.Module):
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
self.return_attention_masks = return_attention_masks
if layer == "hidden":
assert layer_idx is not None
@ -136,7 +191,7 @@ class SDClipModel(torch.nn.Module):
tokens = torch.tensor(tokens, dtype=torch.long).to(device)
attention_mask = None
if self.enable_attention_masks:
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
attention_mask = torch.zeros_like(tokens)
end_token = self.special_tokens.get("end", -1)
for x in range(attention_mask.shape[0]):
@ -145,7 +200,11 @@ class SDClipModel(torch.nn.Module):
if tokens[x, y] == end_token:
break
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
attention_mask_model = None
if self.enable_attention_masks:
attention_mask_model = attention_mask
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
@ -153,7 +212,7 @@ class SDClipModel(torch.nn.Module):
else:
z = outputs[1].float()
if self.zero_out_masked and attention_mask is not None:
if self.zero_out_masked:
z *= attention_mask.unsqueeze(-1).float()
pooled_output = None
@ -163,6 +222,13 @@ class SDClipModel(torch.nn.Module):
elif outputs[2] is not None:
pooled_output = outputs[2].float()
extra = {}
if self.return_attention_masks:
extra["attention_mask"] = attention_mask
if len(extra) > 0:
return z, pooled_output, extra
return z, pooled_output
def encode(self, tokens):
@ -374,10 +440,13 @@ SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
tokenizer_path = files.get_package_as_path("comfy.sd1_tokenizer")
if isinstance(tokenizer_path, Traversable):
contextlib_path = importlib.resources.as_file(tokenizer_path)
tokenizer_path = contextlib_path.__enter__()
if not tokenizer_path.endswith(".model") and not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
# package based
tokenizer_path = get_package_as_path('comfy.sd1_tokenizer')
self.tokenizer_class = tokenizer_class
@ -395,6 +464,14 @@ class SDTokenizer:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
if pad_token is not None:
self.pad_token = pad_token
elif pad_with_end:
self.pad_token = self.end_token
else:
self.pad_token = 0
self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length
self.additional_tokens: Tuple[str, ...] = ()
@ -439,10 +516,6 @@ class SDTokenizer:
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
Returned list has the dimensions NxM where M is the input size of CLIP
'''
if self.pad_with_end:
pad_token = self.end_token
else:
pad_token = 0
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
@ -502,7 +575,7 @@ class SDTokenizer:
else:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
# start new batch
batch = []
if self.start_token is not None:
@ -515,9 +588,9 @@ class SDTokenizer:
# fill last batch
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
@ -560,10 +633,16 @@ class SD1Tokenizer:
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs):
super().__init__()
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
if name is not None:
self.clip_name = name
self.clip = "{}".format(self.clip_name)
else:
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))
self.dtypes = set()
@ -578,8 +657,8 @@ class SD1ClipModel(torch.nn.Module):
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name]
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
return out, pooled
out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
return out
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)

View File

@ -1,39 +1,45 @@
import logging
import os
import torch
from transformers import T5TokenizerFast
import comfy.model_management
import comfy.t5
from comfy import sd1_clip
from comfy import sdxl_clip
from transformers import T5TokenizerFast
import comfy.t5
import torch
import os
import comfy.model_management
import logging
from comfy.component_model import files
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
class SDT5XXLModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
class SD3Tokenizer:
def __init__(self, embedding_directory=None):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
@ -43,6 +49,7 @@ class SD3Tokenizer:
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
super().__init__()
@ -143,8 +150,10 @@ class SD3ClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
return SD3ClipModel_

View File

@ -7,6 +7,7 @@ from . import sd2_clip
from . import sdxl_clip
from . import sd3_clip
from . import sa_t5
from .text_encoders import aura_t5
from . import supported_models_base
from . import latent_formats
@ -556,7 +557,29 @@ class StableAudio(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
class AuraFlow(supported_models_base.BASE):
unet_config = {
"cond_seq_dim": 2048,
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
sampling_settings = {
"multiplier": 1.0,
"shift": 1.73,
}
unet_extra_config = {}
latent_format = latent_formats.SDXL
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.AuraFlow(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(aura_t5.AuraT5Tokenizer, aura_t5.AuraT5Model)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
models += [SVD_img2vid]

View File

@ -13,29 +13,36 @@ class T5LayerNorm(torch.nn.Module):
x = x * torch.rsqrt(variance + self.variance_epsilon)
return self.weight.to(device=x.device, dtype=x.dtype) * x
activations = {
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
"relu": torch.nn.functional.relu,
}
class T5DenseActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device, operations):
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
super().__init__()
self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
# self.dropout = nn.Dropout(config.dropout_rate)
self.act = activations[ff_activation]
def forward(self, x):
x = torch.nn.functional.relu(self.wi(x))
x = self.act(self.wi(x))
# x = self.dropout(x)
x = self.wo(x)
return x
class T5DenseGatedActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device, operations):
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
super().__init__()
self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
# self.dropout = nn.Dropout(config.dropout_rate)
self.act = activations[ff_activation]
def forward(self, x):
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
hidden_gelu = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_gelu * hidden_linear
# x = self.dropout(x)
@ -43,12 +50,12 @@ class T5DenseGatedActDense(torch.nn.Module):
return x
class T5LayerFF(torch.nn.Module):
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
super().__init__()
if ff_activation == "gelu_pytorch_tanh":
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations)
elif ff_activation == "relu":
self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations)
if gated_act:
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
else:
self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
# self.dropout = nn.Dropout(config.dropout_rate)
@ -171,11 +178,11 @@ class T5LayerSelfAttention(torch.nn.Module):
return x, past_bias
class T5Block(torch.nn.Module):
def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations):
def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList()
self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations))
self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations))
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
@ -183,11 +190,11 @@ class T5Block(torch.nn.Module):
return x, past_bias
class T5Stack(torch.nn.Module):
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations):
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention, dtype, device, operations):
super().__init__()
self.block = torch.nn.ModuleList(
[T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
[T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=((not relative_attention) or (i == 0)), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
)
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
# self.dropout = nn.Dropout(config.dropout_rate)
@ -216,7 +223,7 @@ class T5(torch.nn.Module):
self.num_layers = config_dict["num_layers"]
model_dim = config_dict["d_model"]
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations)
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] == "t5", dtype, device, operations)
self.dtype = dtype
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)

View File

@ -8,6 +8,7 @@
"dense_act_fn": "relu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": false,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 12,

View File

@ -8,6 +8,7 @@
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,

View File

View File

@ -0,0 +1,28 @@
from importlib import resources
from comfy import sd1_clip
from .llama_tokenizer import LLAMATokenizer
from .. import t5
from ..component_model.files import get_path_as_dict
class PT5XlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
textmodel_json_config = get_path_as_dict(textmodel_json_config, "t5_pile_config_xl.json", package="comfy.text_encoders")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=t5.T5, enable_attention_masks=True, zero_out_masked=True)
class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
tokenizer_path = resources.files("comfy.text_encoders.t5_pile_tokenizer") / "tokenizer.model"
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LLAMATokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
class AuraT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)

View File

@ -0,0 +1,22 @@
import os
class LLAMATokenizer:
@staticmethod
def from_pretrained(path):
return LLAMATokenizer(path)
def __init__(self, tokenizer_path):
import sentencepiece
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
self.end = self.tokenizer.eos_id()
def get_vocab(self):
out = {}
for i in range(self.tokenizer.get_piece_size()):
out[self.tokenizer.id_to_piece(i)] = i
return out
def __call__(self, string):
out = self.tokenizer.encode(string)
out += [self.end]
return {"input_ids": out}

View File

@ -0,0 +1,22 @@
{
"d_ff": 5120,
"d_kv": 64,
"d_model": 2048,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 2,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "umt5",
"num_decoder_layers": 24,
"num_heads": 32,
"num_layers": 24,
"output_past": true,
"pad_token_id": 1,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 32128
}

Binary file not shown.

View File

@ -20,6 +20,8 @@ from PIL import Image
from tqdm import tqdm
from . import checkpoint_pickle, interruption
from .component_model import files
from .component_model.deprecation import _deprecate_method
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
from .component_model.queue_types import BinaryEventTypes
from .execution_context import current_execution_context
@ -374,6 +376,76 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
return key_map
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
n_double_layers = mmdit_config.get("n_double_layers", 0)
n_layers = mmdit_config.get("n_layers", 0)
key_map = {}
for i in range(n_layers):
if i < n_double_layers:
index = i
prefix_from = "joint_transformer_blocks"
prefix_to = "{}double_layers".format(output_prefix)
block_map = {
"attn.to_q.weight": "attn.w2q.weight",
"attn.to_k.weight": "attn.w2k.weight",
"attn.to_v.weight": "attn.w2v.weight",
"attn.to_out.0.weight": "attn.w2o.weight",
"attn.add_q_proj.weight": "attn.w1q.weight",
"attn.add_k_proj.weight": "attn.w1k.weight",
"attn.add_v_proj.weight": "attn.w1v.weight",
"attn.to_add_out.weight": "attn.w1o.weight",
"ff.linear_1.weight": "mlpX.c_fc1.weight",
"ff.linear_2.weight": "mlpX.c_fc2.weight",
"ff.out_projection.weight": "mlpX.c_proj.weight",
"ff_context.linear_1.weight": "mlpC.c_fc1.weight",
"ff_context.linear_2.weight": "mlpC.c_fc2.weight",
"ff_context.out_projection.weight": "mlpC.c_proj.weight",
"norm1.linear.weight": "modX.1.weight",
"norm1_context.linear.weight": "modC.1.weight",
}
else:
index = i - n_double_layers
prefix_from = "single_transformer_blocks"
prefix_to = "{}single_layers".format(output_prefix)
block_map = {
"attn.to_q.weight": "attn.w1q.weight",
"attn.to_k.weight": "attn.w1k.weight",
"attn.to_v.weight": "attn.w1v.weight",
"attn.to_out.0.weight": "attn.w1o.weight",
"norm1.linear.weight": "modCX.1.weight",
"ff.linear_1.weight": "mlp.c_fc1.weight",
"ff.linear_2.weight": "mlp.c_fc2.weight",
"ff.out_projection.weight": "mlp.c_proj.weight"
}
for k in block_map:
key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])
MAP_BASIC = {
("positional_encoding", "pos_embed.pos_embed"),
("register_tokens", "register_tokens"),
("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
("cond_seq_linear.weight", "context_embedder.weight"),
("init_x_linear.weight", "pos_embed.proj.weight"),
("init_x_linear.bias", "pos_embed.proj.bias"),
("final_linear.weight", "proj_out.weight"),
("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
}
for k in MAP_BASIC:
if len(k) > 2:
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
else:
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
return key_map
def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size)
@ -675,8 +747,9 @@ class ProgressBar:
self.update_absolute(self.current + value)
@_deprecate_method(version="1.0.0", message="The root project directory isn't valid when the application is installed as a package. Use os.getcwd() instead.")
def get_project_root() -> str:
return os.path.join(os.path.dirname(__file__), "..")
return files.get_package_as_path("comfy")
@contextmanager

View File

@ -1599,7 +1599,7 @@ export class ComfyApp {
if (json) {
const workflow = JSON.parse(json);
const workflowName = getStorageValue("Comfy.PreviousWorkflow");
await this.loadGraphData(workflow, true, workflowName);
await this.loadGraphData(workflow, true, true, workflowName);
return true;
}
};

View File

@ -182,6 +182,11 @@ export class ComfyWorkflowsMenu {
* @param {ComfyWorkflow} workflow
*/
async function sendToWorkflow(img, workflow) {
const openWorkflow = app.workflowManager.openWorkflows.find((w) => w.path === workflow.path);
if (openWorkflow) {
workflow = openWorkflow;
}
await workflow.load();
let options = [];
const nodes = app.graph.computeExecutionOrder(false);
@ -214,7 +219,8 @@ export class ComfyWorkflowsMenu {
nodeType.prototype["getExtraMenuOptions"] = function (_, options) {
const r = getExtraMenuOptions?.apply?.(this, arguments);
if (app.ui.settings.getSettingValue("Comfy.UseNewMenu", false) === true) {
const setting = app.ui.settings.getSettingValue("Comfy.UseNewMenu", false);
if (setting && setting != "Disabled") {
const t = /** @type { {imageIndex?: number, overIndex?: number, imgs: string[]} } */ /** @type {any} */ (this);
let img;
if (t.imageIndex != null) {

View File

@ -41,7 +41,7 @@ body {
background-color: var(--bg-color);
color: var(--fg-color);
grid-template-columns: auto 1fr auto;
grid-template-rows: auto auto 1fr auto;
grid-template-rows: auto 1fr auto;
min-height: -webkit-fill-available;
max-height: -webkit-fill-available;
min-width: -webkit-fill-available;
@ -49,32 +49,37 @@ body {
}
.comfyui-body-top {
order: 0;
order: -5;
grid-column: 1/-1;
z-index: 10;
display: flex;
flex-direction: column;
}
.comfyui-body-left {
order: 1;
order: -4;
z-index: 10;
display: flex;
}
#graph-canvas {
width: 100%;
height: 100%;
order: 2;
grid-column: 1/-1;
order: -3;
}
.comfyui-body-right {
order: 3;
order: -2;
z-index: 10;
display: flex;
}
.comfyui-body-bottom {
order: 4;
order: -1;
grid-column: 1/-1;
z-index: 10;
display: flex;
flex-direction: column;
}
.comfy-multiline-input {
@ -408,8 +413,12 @@ dialog::backdrop {
background: rgba(0, 0, 0, 0.5);
}
.comfy-dialog.comfyui-dialog {
.comfy-dialog.comfyui-dialog.comfy-modal {
top: 0;
left: 0;
right: 0;
bottom: 0;
transform: none;
}
.comfy-dialog.comfy-modal {

View File

@ -20,7 +20,7 @@ class EmptyLatentAudio:
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "_for_testing/audio"
CATEGORY = "latent/audio"
def generate(self, seconds):
batch_size = 1
@ -35,7 +35,7 @@ class VAEEncodeAudio:
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "_for_testing/audio"
CATEGORY = "latent/audio"
def encode(self, vae, audio):
sample_rate = audio["sample_rate"]
@ -55,7 +55,7 @@ class VAEDecodeAudio:
RETURN_TYPES = ("AUDIO",)
FUNCTION = "decode"
CATEGORY = "_for_testing/audio"
CATEGORY = "latent/audio"
def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1)
@ -134,7 +134,7 @@ class SaveAudio:
OUTPUT_NODE = True
CATEGORY = "_for_testing/audio"
CATEGORY = "audio"
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
import torchaudio # pylint: disable=import-error
@ -199,7 +199,7 @@ class LoadAudio:
]
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
CATEGORY = "_for_testing/audio"
CATEGORY = "audio"
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
@ -209,7 +209,6 @@ class LoadAudio:
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
multiplier = 1.0
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )

View File

@ -147,7 +147,7 @@ class ModelSamplingSD3:
CATEGORY = "advanced/model"
def patch(self, model, shift):
def patch(self, model, shift, multiplier=1000):
m = model.clone()
sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
@ -157,10 +157,22 @@ class ModelSamplingSD3:
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift)
model_sampling.set_parameters(shift=shift, multiplier=multiplier)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class ModelSamplingAuraFlow(ModelSamplingSD3):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 1.73, "min": 0.0, "max": 100.0, "step":0.01}),
}}
FUNCTION = "patch_aura"
def patch_aura(self, model, shift):
return self.patch(model, shift, multiplier=1.0)
class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
@ -276,5 +288,6 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingContinuousV": ModelSamplingContinuousV,
"ModelSamplingStableCascade": ModelSamplingStableCascade,
"ModelSamplingSD3": ModelSamplingSD3,
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
"RescaleCFG": RescaleCFG,
}

View File

@ -5,6 +5,8 @@ torchsde>=0.2.6
einops>=0.6.0
open-clip-torch>=2.24.0
transformers>=4.29.1
tokenizers>=0.13.3
sentencepiece
peft
torchinfo
safetensors>=0.4.2

View File

@ -192,6 +192,7 @@ package_data = [
't5_tokenizer/*',
'**/*.json',
'**/*.yaml',
'**/*.model'
]
if not is_editable:
package_data.append('comfy/web/**/*')

View File

@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
import jwt
import pytest
from aiohttp import ClientSession, ClientConnectorError
from aiohttp import ClientSession
from testcontainers.rabbitmq import RabbitMqContainer
from comfy.client.aio_client import AsyncRemoteComfyClient
@ -132,13 +132,11 @@ async def test_basic_queue_worker_with_health_check():
health_check_port = 9090
async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
# Test health check
health_check_url = f"http://localhost:{health_check_port}/health"
health_check_ok = await check_health(health_check_url)
assert health_check_ok, "Health check server did not start properly"
# Test the actual worker functionality
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
await distributed_queue.init()
@ -153,53 +151,5 @@ async def test_basic_queue_worker_with_health_check():
await distributed_queue.close()
# Test that the health check server is stopped after the worker is closed
health_check_stopped = not await check_health(health_check_url, max_retries=1)
assert health_check_stopped, "Health check server did not stop properly"
@pytest.mark.asyncio
async def test_health_check_port_conflict():
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
health_check_port = 9090
# Start a simple server to occupy the health check port
from aiohttp import web
async def dummy_handler(request):
return web.Response(text="Dummy")
app = web.Application()
app.router.add_get('/', dummy_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, '0.0.0.0', health_check_port)
await site.start()
try:
# Now try to start the DistributedPromptWorker
async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
# The health check should be disabled, but the worker should still function
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
await distributed_queue.init()
queue_item = create_test_prompt()
res = await distributed_queue.put_async(queue_item)
assert res.item_id == queue_item.prompt_id
assert len(res.outputs) == 1
assert res.status is not None
assert res.status.status_str == "success"
await distributed_queue.close()
# The original server should still be running
async with ClientSession() as session:
async with session.get(f"http://localhost:{health_check_port}") as response:
assert response.status == 200
assert await response.text() == "Dummy"
finally:
await runner.cleanup()
assert health_check_stopped, "Health check server did not stop properly"

View File

@ -1,7 +1,6 @@
import pytest
from comfy.api.components.schema.prompt import Prompt
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.model_downloader import add_known_models, KNOWN_LORAS
from comfy.model_downloader_types import CivitFile
@ -139,9 +138,7 @@ _workflows = {
@pytest.fixture(scope="module", autouse=False)
@pytest.mark.asyncio
async def client(tmp_path_factory) -> EmbeddedComfyClient:
config = Configuration()
config.cwd = str(tmp_path_factory.mktemp("comfy_test_cwd"))
async with EmbeddedComfyClient(config) as client:
async with EmbeddedComfyClient() as client:
yield client