mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
3d1d833e6f
9
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
9
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -7,9 +7,12 @@ body:
|
||||
value: |
|
||||
Before submitting a **Bug Report**, please ensure the following:
|
||||
|
||||
**1:** You are running the latest version of ComfyUI.
|
||||
**2:** You have looked at the existing bug reports and made sure this isn't already reported.
|
||||
**3:** This is an actual bug in ComfyUI, not just a support question and not caused by an custom node. A bug is when you can specify exact steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||
- **1:** You are running the latest version of ComfyUI.
|
||||
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
|
||||
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
|
||||
`--disable-all-custom-nodes` command line argument.
|
||||
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
|
||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||
|
||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
||||
- type: textarea
|
||||
|
||||
109
.github/workflows/stable-release.yml
vendored
Normal file
109
.github/workflows/stable-release.yml
vendored
Normal file
@ -0,0 +1,109 @@
|
||||
|
||||
name: "Release Stable Version"
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
jobs:
|
||||
package_comfy_windows:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python_version: [3.11.8]
|
||||
cuda_version: [121]
|
||||
steps:
|
||||
- name: Calculate Minor Version
|
||||
shell: bash
|
||||
run: |
|
||||
# Extract the minor version from the Python version
|
||||
MINOR_VERSION=$(echo "${{ matrix.python_version }}" | cut -d'.' -f2)
|
||||
echo "MINOR_VERSION=$MINOR_VERSION" >> $GITHUB_ENV
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- shell: bash
|
||||
run: |
|
||||
echo "@echo off
|
||||
call update_comfyui.bat nopause
|
||||
echo -
|
||||
echo This will try to update pytorch and all python dependencies.
|
||||
echo -
|
||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
||||
echo -
|
||||
pause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r requirements.txt pygit2 -w ./temp_wheel_dir
|
||||
python -m pip install --no-cache-dir ./temp_wheel_dir/*
|
||||
echo installed basic
|
||||
ls -lah temp_wheel_dir
|
||||
mv temp_wheel_dir cu${{ matrix.cuda_version }}_python_deps
|
||||
mv cu${{ matrix.cuda_version }}_python_deps ../
|
||||
mv update_comfyui_and_python_dependencies.bat ../
|
||||
cd ..
|
||||
pwd
|
||||
ls
|
||||
|
||||
cp -r ComfyUI ComfyUI_copy
|
||||
curl https://www.python.org/ftp/python/${{ matrix.python_version }}/python-${{ matrix.python_version }}-embed-amd64.zip -o python_embeded.zip
|
||||
unzip python_embeded.zip -d python_embeded
|
||||
cd python_embeded
|
||||
echo ${{ env.MINOR_VERSION }}
|
||||
echo 'import site' >> ./python3${{ env.MINOR_VERSION }}._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
./python.exe --version
|
||||
echo "Pip version:"
|
||||
./python.exe -m pip --version
|
||||
|
||||
set PATH=$PWD/Scripts:$PATH
|
||||
echo $PATH
|
||||
./python.exe -s -m pip install ../cu${{ matrix.cuda_version }}_python_deps/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ env.MINOR_VERSION }}._pth
|
||||
cd ..
|
||||
|
||||
git clone https://github.com/comfyanonymous/taesd
|
||||
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
|
||||
|
||||
mkdir ComfyUI_windows_portable
|
||||
mv python_embeded ComfyUI_windows_portable
|
||||
mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
|
||||
|
||||
cd ComfyUI_windows_portable
|
||||
python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
|
||||
|
||||
ls
|
||||
|
||||
- name: Upload binaries to release
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
repo_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
file: ComfyUI_windows_portable_nvidia.7z
|
||||
tag: ${{ github.ref }}
|
||||
overwrite: true
|
||||
|
||||
15
.github/workflows/test-browser.yml
vendored
15
.github/workflows/test-browser.yml
vendored
@ -41,7 +41,7 @@ jobs:
|
||||
working-directory: ComfyUI
|
||||
- name: Start ComfyUI server
|
||||
run: |
|
||||
python main.py --cpu &
|
||||
python main.py --cpu 2>&1 | tee console_output.log &
|
||||
wait-for-it --service 127.0.0.1:8188 -t 600
|
||||
working-directory: ComfyUI
|
||||
- name: Install ComfyUI_frontend dependencies
|
||||
@ -54,9 +54,22 @@ jobs:
|
||||
- name: Run Playwright tests
|
||||
run: npx playwright test
|
||||
working-directory: ComfyUI_frontend
|
||||
- name: Check for unhandled exceptions in server log
|
||||
run: |
|
||||
if grep -qE "Exception|Error" console_output.log; then
|
||||
echo "Unhandled exception/error found in server log."
|
||||
exit 1
|
||||
fi
|
||||
working-directory: ComfyUI
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
path: ComfyUI_frontend/playwright-report/
|
||||
retention-days: 30
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: console-output
|
||||
path: ComfyUI/console_output.log
|
||||
retention-days: 30
|
||||
|
||||
@ -42,6 +42,7 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
|
||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: will never download anything.
|
||||
|
||||
@ -10,10 +10,51 @@ from ..ldm.modules.diffusionmodules.util import (
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from ..ldm.modules.attention import SpatialTransformer
|
||||
from ..ldm.modules.attention import SpatialTransformer, optimized_attention
|
||||
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
|
||||
from ..ldm.util import exists
|
||||
from .. import ops
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class OptimizedAttention(nn.Module):
    """Self-attention layer with a single fused q/k/v input projection.

    Args:
        c: Feature width of the input and of each of q, k and v.
        nhead: Number of attention heads handed to ``optimized_attention``.
        dropout: Accepted for signature compatibility; never read here.
        dtype: Forwarded to the ``operations.Linear`` constructors.
        device: Forwarded to the ``operations.Linear`` constructors.
        operations: Namespace that provides the ``Linear`` layer class.
    """

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead
        self.c = c

        factory_kwargs = {"dtype": dtype, "device": device}
        # One projection emits q, k and v side by side (3 * c features).
        self.in_proj = operations.Linear(c, c * 3, bias=True, **factory_kwargs)
        self.out_proj = operations.Linear(c, c, bias=True, **factory_kwargs)

    def forward(self, x):
        """Project ``x`` to q/k/v, attend, and apply the output projection."""
        qkv = self.in_proj(x)
        # Cut dim 2 into consecutive pieces of width c: q, then k, then v.
        query, key, value = qkv.split(self.c, dim=2)
        attended = optimized_attention(query, key, value, self.heads)
        return self.out_proj(attended)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
    """Activation computing ``x * sigmoid(1.702 * x)`` element-wise."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
||||
|
||||
|
||||
class ResBlockUnionControlnet(nn.Module):
    """Pre-norm residual block: attention followed by a 4x-wide MLP.

    Args:
        dim: Feature width of the block.
        nhead: Number of attention heads.
        dtype: Forwarded to every layer constructor.
        device: Forwarded to every layer constructor.
        operations: Namespace that provides ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
        super().__init__()
        factory_kwargs = {"dtype": dtype, "device": device}

        self.attn = OptimizedAttention(dim, nhead, operations=operations, **factory_kwargs)
        self.ln_1 = operations.LayerNorm(dim, **factory_kwargs)

        # Sub-module names ("c_fc", "gelu", "c_proj") become state-dict keys.
        feed_forward = OrderedDict()
        feed_forward["c_fc"] = operations.Linear(dim, dim * 4, **factory_kwargs)
        feed_forward["gelu"] = QuickGELU()
        feed_forward["c_proj"] = operations.Linear(dim * 4, dim, **factory_kwargs)
        self.mlp = nn.Sequential(feed_forward)

        self.ln_2 = operations.LayerNorm(dim, **factory_kwargs)

    def attention(self, x: torch.Tensor):
        """Run the attention sub-layer on ``x``."""
        return self.attn(x)

    def forward(self, x: torch.Tensor):
        """Apply both residual sub-layers and return the result."""
        attn_out = self.attention(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
#implemented in the ldm unet
|
||||
@ -53,6 +94,7 @@ class ControlNet(nn.Module):
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
attn_precision=None,
|
||||
union_controlnet_num_control_type=None,
|
||||
device=None,
|
||||
operations=ops.disable_weight_init,
|
||||
**kwargs,
|
||||
@ -280,6 +322,65 @@ class ControlNet(nn.Module):
|
||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
|
||||
self._feature_size += ch
|
||||
|
||||
if union_controlnet_num_control_type is not None:
|
||||
self.num_control_type = union_controlnet_num_control_type
|
||||
num_trans_channel = 320
|
||||
num_trans_head = 8
|
||||
num_trans_layer = 1
|
||||
num_proj_channel = 320
|
||||
# task_scale_factor = num_trans_channel ** 0.5
|
||||
self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
|
||||
|
||||
self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
|
||||
self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
|
||||
#-----------------------------------------------------------------------------------------------------
|
||||
|
||||
control_add_embed_dim = 256
|
||||
class ControlAddEmbedding(nn.Module):
    """Turns the selected control type(s) into an additive embedding.

    Args:
        in_dim: Embedding width produced per control-type slot.
        out_dim: Width of the returned embedding.
        num_control_type: Number of control-type slots.
        dtype: Forwarded to the ``operations.Linear`` constructors.
        device: Forwarded to the ``operations.Linear`` constructors.
        operations: Namespace that provides the ``Linear`` layer class.
    """

    def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
        super().__init__()
        self.num_control_type = num_control_type
        self.in_dim = in_dim

        factory_kwargs = {"dtype": dtype, "device": device}
        self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, **factory_kwargs)
        self.linear_2 = operations.Linear(out_dim, out_dim, **factory_kwargs)

    def forward(self, control_type, dtype, device):
        """Embed ``control_type`` and return the projected result.

        Args:
            control_type: Index (or indices) of the slot(s) to set to 1.0.
            dtype: dtype the embedding is cast to before the linear layers.
            device: Device on which the indicator vector is created.
        """
        # Indicator vector: 1.0 at the selected slot(s), 0.0 elsewhere.
        indicator = torch.zeros((self.num_control_type,), device=device)
        indicator[control_type] = 1.0

        # Embed every slot value, then flatten all slots into one row.
        embedded = timestep_embedding(indicator.flatten(), self.in_dim, repeat_only=False)
        embedded = embedded.to(dtype).reshape((-1, self.num_control_type * self.in_dim))

        hidden = torch.nn.functional.silu(self.linear_1(embedded))
        return self.linear_2(hidden)
|
||||
|
||||
self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.task_embedding = None
|
||||
self.control_add_embedding = None
|
||||
|
||||
def union_controlnet_merge(self, hint, control_type, emb, context):
    """Fuse the hint of a union ControlNet into one conditioning tensor.

    Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main

    Only the first entry of ``control_type`` (and ``hint[0]``) is encoded.
    The fusion loop now walks the conditions that were actually encoded, so
    passing more than one control type no longer raises ``IndexError``.

    Args:
        hint: Indexable collection of hint tensors; ``hint[idx]`` is passed to
            ``self.input_hint_block``.
        control_type: Non-empty sequence of indices into
            ``self.task_embedding``. With an empty sequence ``torch.cat`` has
            nothing to concatenate and raises.
        emb: Passed through to ``self.input_hint_block``.
        context: Passed through to ``self.input_hint_block``.

    Returns:
        The sum over encoded conditions of (encoded hint + projected
        transformer output broadcast over the last two dims).
    """
    # Same limit as before: at most one control type / hint is encoded.
    num_conditions = min(1, len(control_type))

    inputs = []
    condition_list = []
    for idx in range(num_conditions):
        controlnet_cond = self.input_hint_block(hint[idx], emb, context)
        # Average over dims 2 and 3 to get one feature vector per sample.
        feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
        feat_seq += self.task_embedding[control_type[idx]]

        inputs.append(feat_seq.unsqueeze(1))
        condition_list.append(controlnet_cond)

    x = torch.cat(inputs, dim=1)
    x = self.transformer_layes(x)

    controlnet_cond_fuser = None
    # Iterate over what was encoded above, not over len(control_type):
    # x and condition_list only hold num_conditions entries.
    for idx, controlnet_cond in enumerate(condition_list):
        alpha = self.spatial_ch_projs(x[:, idx])
        # Add two trailing singleton dims so alpha broadcasts over them.
        alpha = alpha.unsqueeze(-1).unsqueeze(-1)
        o = controlnet_cond + alpha
        if controlnet_cond_fuser is None:
            controlnet_cond_fuser = o
        else:
            controlnet_cond_fuser += o
    return controlnet_cond_fuser
|
||||
|
||||
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
    """Return a kernel-size-1, padding-0 ``conv_nd`` wrapped in ``TimestepEmbedSequential``.

    The convolution maps ``channels`` to ``channels`` and uses ``self.dims``
    as its dimensionality.
    """
    conv = operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)
    return TimestepEmbedSequential(conv)
|
||||
|
||||
@ -287,6 +388,17 @@ class ControlNet(nn.Module):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
guided_hint = None
|
||||
if self.control_add_embedding is not None: #Union Controlnet
|
||||
control_type = kwargs.get("control_type", [])
|
||||
|
||||
emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
|
||||
if len(control_type) > 0:
|
||||
if len(hint.shape) < 5:
|
||||
hint = hint.unsqueeze(dim=0)
|
||||
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
|
||||
|
||||
if guided_hint is None:
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
|
||||
out_output = []
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from .component_model import files
|
||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||
import os
|
||||
import torch
|
||||
@ -30,9 +31,17 @@ def clip_preprocess(image, size=224):
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
def __init__(self, json_config: dict | str):
|
||||
if isinstance(json_config, dict):
|
||||
config = json_config
|
||||
elif json_config is not None and isinstance(json_config, str):
|
||||
if json_config.startswith("{"):
|
||||
config = json.loads(json_config)
|
||||
else:
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"json_config had invalid value={json_config}")
|
||||
|
||||
self.load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
@ -88,12 +97,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
if convert_keys:
|
||||
sd = convert_to_transformers(sd, prefix)
|
||||
if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
|
||||
# todo: fix the importlib issue here
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
|
||||
json_config = files.get_path_as_dict(None, "clip_vision_config_g.json")
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
json_config = files.get_path_as_dict(None, "clip_vision_config_h.json")
|
||||
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@ -85,7 +85,8 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||
|
||||
def cleanup_temp():
|
||||
try:
|
||||
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
||||
folder_paths.get_temp_directory()
|
||||
temp_dir = folder_paths.get_temp_directory()
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
except NameError:
|
||||
@ -115,7 +116,7 @@ async def main():
|
||||
|
||||
# configure extra model paths earlier
|
||||
try:
|
||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||
extra_model_paths_config_path = os.path.join(os.getcwd(), "extra_model_paths.yaml")
|
||||
if os.path.isfile(extra_model_paths_config_path):
|
||||
load_extra_path_config(extra_model_paths_config_path)
|
||||
except NameError:
|
||||
|
||||
@ -439,6 +439,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
info['name'] = node_class
|
||||
info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else ''
|
||||
info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
|
||||
info['category'] = 'sd'
|
||||
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
|
||||
info['output_node'] = True
|
||||
@ -845,18 +846,9 @@ class PromptServer(ExecutorToClientProgress):
|
||||
|
||||
return json_data
|
||||
|
||||
@classmethod
|
||||
def get_output_path(cls, subfolder: str | None = None, filename: str | None = None):
|
||||
paths = [path for path in ["output", subfolder, filename] if path is not None and path != ""]
|
||||
return os.path.join(os.path.dirname(os.path.realpath(__file__)), *paths)
|
||||
|
||||
@classmethod
|
||||
def get_upload_dir(cls) -> str:
|
||||
upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../input")
|
||||
|
||||
if not os.path.exists(upload_dir):
|
||||
os.makedirs(upload_dir)
|
||||
return upload_dir
|
||||
return folder_paths.get_input_directory()
|
||||
|
||||
@classmethod
|
||||
def get_too_busy_queue_size(cls):
|
||||
|
||||
@ -19,7 +19,7 @@ def get_path_as_dict(config_dict_or_path: str | dict | None, config_path_inside_
|
||||
config: dict | None = None
|
||||
|
||||
if config_dict_or_path is None:
|
||||
config_dict_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config_path_inside_package)
|
||||
config_dict_or_path = config_path_inside_package
|
||||
|
||||
if isinstance(config_dict_or_path, str):
|
||||
if config_dict_or_path.startswith("{"):
|
||||
|
||||
@ -412,6 +412,12 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if k in controlnet_data:
|
||||
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
||||
|
||||
if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
|
||||
controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
|
||||
for k in list(controlnet_data.keys()):
|
||||
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
|
||||
new_sd[new_k] = controlnet_data.pop(k)
|
||||
|
||||
leftover_keys = controlnet_data.keys()
|
||||
if len(leftover_keys) > 0:
|
||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||
|
||||
479
comfy/ldm/aura/mmdit.py
Normal file
479
comfy/ldm/aura/mmdit.py
Normal file
@ -0,0 +1,479 @@
|
||||
#AuraFlow MMDiT
|
||||
#Originally written by the AuraFlow Authors
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
def modulate(x, shift, scale):
    """Return ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each get a new axis inserted at position 1
    before the arithmetic, so they broadcast along that axis of ``x``.
    """
    scale_factor = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * scale_factor + offset
|
||||
|
||||
|
||||
def find_multiple(n: int, k: int) -> int:
    """Return ``n`` rounded up to a multiple of ``k`` (``n`` itself if it already is one)."""
    remainder = n % k
    if remainder:
        return n + k - remainder
    return n
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Gated feed-forward layer: ``c_proj(silu(c_fc1(x)) * c_fc2(x))``.

    Args:
        dim: Input and output feature width.
        hidden_dim: Nominal hidden width; defaults to ``4 * dim``. The width
            actually used is ``int(2 * hidden_dim / 3)`` rounded up to a
            multiple of 256.
        dtype: Forwarded to the ``operations.Linear`` constructors.
        device: Forwarded to the ``operations.Linear`` constructors.
        operations: Namespace that provides the ``Linear`` layer class.
    """

    def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        hidden_dim = 4 * dim if hidden_dim is None else hidden_dim

        # Two thirds of hidden_dim (truncated), padded up to a multiple of 256.
        n_hidden = find_multiple(int(2 * hidden_dim / 3), 256)

        linear_kwargs = {"bias": False, "dtype": dtype, "device": device}
        self.c_fc1 = operations.Linear(dim, n_hidden, **linear_kwargs)
        self.c_fc2 = operations.Linear(dim, n_hidden, **linear_kwargs)
        self.c_proj = operations.Linear(n_hidden, dim, **linear_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x``."""
        gate = F.silu(self.c_fc1(x))
        value = self.c_fc2(x)
        return self.c_proj(gate * value)
|
||||
|
||||
|
||||
class MultiHeadLayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
|
||||
# Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78
|
||||
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
mean = hidden_states.mean(-1, keepdim=True)
|
||||
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = (hidden_states - mean) * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
hidden_states = self.weight.to(torch.float32) * hidden_states
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
class SingleAttention(nn.Module):
    """Multi-head self-attention over one token stream with q/k normalization.

    Args:
        dim: Feature width; each head gets ``dim // n_heads`` features.
        n_heads: Number of attention heads.
        mh_qknorm: If true, q and k are normalized with
            ``MultiHeadLayerNorm``; otherwise with a non-affine
            ``operations.LayerNorm`` over the head dimension.
        dtype: Forwarded to every layer constructor.
        device: Forwarded to every layer constructor.
        operations: Namespace that provides ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        def make_proj():
            return operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        def make_qk_norm():
            if mh_qknorm:
                return MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            return operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)

        # this is for cond
        self.w1q = make_proj()
        self.w1k = make_proj()
        self.w1v = make_proj()
        self.w1o = make_proj()

        self.q_norm1 = make_qk_norm()
        self.k_norm1 = make_qk_norm()

    #@torch.compile()
    def forward(self, c):
        """Attend over ``c`` (unpacked as batch, sequence, features) and project the result."""
        bsz, seqlen1, _ = c.shape
        head_shape = (bsz, seqlen1, self.n_heads, self.head_dim)

        query = self.q_norm1(self.w1q(c).view(head_shape))
        key = self.k_norm1(self.w1k(c).view(head_shape))
        value = self.w1v(c).view(head_shape)

        # Swap the sequence and head axes before attention; skip_reshape=True
        # is passed because the heads are already split out here.
        attended = optimized_attention(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
            self.n_heads,
            skip_reshape=True,
        )
        return self.w1o(attended)
|
||||
|
||||
|
||||
|
||||
class DoubleAttention(nn.Module):
    """Joint multi-head attention over two token streams.

    Stream ``c`` uses the ``w1*`` projections and ``*_norm1`` layers, stream
    ``x`` uses ``w2*`` and ``*_norm2``. Their q/k/v are concatenated along
    the sequence axis for a single attention call, and the result is split
    back into the two streams before the output projections.

    Args:
        dim: Feature width; each head gets ``dim // n_heads`` features.
        n_heads: Number of attention heads.
        mh_qknorm: If true, q and k are normalized with
            ``MultiHeadLayerNorm``; otherwise with a non-affine
            ``operations.LayerNorm`` over the head dimension.
        dtype: Forwarded to every layer constructor.
        device: Forwarded to every layer constructor.
        operations: Namespace that provides ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        def make_proj():
            return operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        def make_qk_norm():
            if mh_qknorm:
                return MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            return operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)

        # this is for cond
        self.w1q = make_proj()
        self.w1k = make_proj()
        self.w1v = make_proj()
        self.w1o = make_proj()

        # this is for x
        self.w2q = make_proj()
        self.w2k = make_proj()
        self.w2v = make_proj()
        self.w2o = make_proj()

        self.q_norm1 = make_qk_norm()
        self.k_norm1 = make_qk_norm()

        self.q_norm2 = make_qk_norm()
        self.k_norm2 = make_qk_norm()

    #@torch.compile()
    def forward(self, c, x):
        """Attend jointly over ``c`` and ``x``; return the updated ``(c, x)`` pair."""
        _, seqlen1, _ = c.shape
        # The batch size used for every reshape below is taken from x.
        bsz, seqlen2, _ = x.shape

        def to_heads(tensor, seqlen):
            return tensor.view(bsz, seqlen, self.n_heads, self.head_dim)

        cq = self.q_norm1(to_heads(self.w1q(c), seqlen1))
        ck = self.k_norm1(to_heads(self.w1k(c), seqlen1))
        cv = to_heads(self.w1v(c), seqlen1)

        xq = self.q_norm2(to_heads(self.w2q(x), seqlen2))
        xk = self.k_norm2(to_heads(self.w2k(x), seqlen2))
        xv = to_heads(self.w2v(x), seqlen2)

        # concat all: c tokens first, then x tokens, along the sequence axis
        q = torch.cat([cq, xq], dim=1)
        k = torch.cat([ck, xk], dim=1)
        v = torch.cat([cv, xv], dim=1)

        output = optimized_attention(
            q.permute(0, 2, 1, 3),
            k.permute(0, 2, 1, 3),
            v.permute(0, 2, 1, 3),
            self.n_heads,
            skip_reshape=True,
        )

        # Undo the concatenation: the first seqlen1 positions belong to c.
        c_out, x_out = output.split([seqlen1, seqlen2], dim=1)
        return self.w1o(c_out), self.w2o(x_out)
|
||||
|
||||
|
||||
class MMDiTBlock(nn.Module):
    """Transformer block that updates a ``c`` stream and an ``x`` stream jointly.

    Each stream is normalized and modulated by ``global_cond``, both go
    through one ``DoubleAttention`` call, and each then passes through a
    gated MLP with a residual connection back to the block's input.

    Args:
        dim: Feature width of both streams.
        heads: Number of attention heads.
        global_conddim: Width of ``global_cond`` fed to the modulation layers.
        is_last: If true, the ``c`` stream has no MLP and only a shift/scale
            modulation (``modC`` emits ``2 * dim`` instead of ``6 * dim``).
        dtype: Forwarded to every layer constructor.
        device: Forwarded to every layer constructor.
        operations: Namespace that provides ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        if not is_last:
            self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
            )
        else:
            # Last block: only the pre-attention shift and scale exist for c.
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
            )

        self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
        self.modX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.is_last = is_last

    #@torch.compile()
    def forward(self, c, x, global_cond, **kwargs):
        """Run the block and return the updated ``(c, x)`` pair.

        Args:
            c: First token stream.
            x: Second token stream.
            global_cond: Conditioning input to ``modC`` / ``modX``; their
                output is chunked along dim 1.
            **kwargs: Accepted and ignored.
        """
        cres, xres = c, x

        # cpath: the modulation layout depends on how modC was built in
        # __init__ (2 * dim outputs when is_last, 6 * dim otherwise).
        if self.is_last:
            cshift_msa, cscale_msa = self.modC(global_cond).chunk(2, dim=1)
        else:
            cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
                self.modC(global_cond).chunk(6, dim=1)
            )

        c = modulate(self.normC1(c), cshift_msa, cscale_msa)

        # xpath
        xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
            self.modX(global_cond).chunk(6, dim=1)
        )

        x = modulate(self.normX1(x), xshift_msa, xscale_msa)

        # attention
        c, x = self.attn(c, x)

        if not self.is_last:
            c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
            c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
            c = cres + c
        # NOTE(review): when is_last, no gate or MLP exists for the c stream,
        # so c is returned as the attention output. Previously this path
        # could not run at all (self.mlpC is undefined); confirm this is the
        # intended value for c in the last block.

        x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
        x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
        x = xres + x

        return c, x
|
||||
|
||||
class DiTBlock(nn.Module):
    """Single-stream transformer block (like MMDiTBlock, but it only has X).

    Args:
        dim: Feature width of the stream.
        heads: Number of attention heads.
        global_conddim: Width of ``global_cond`` fed to the modulation layer.
        dtype: Forwarded to every layer constructor.
        device: Forwarded to every layer constructor.
        operations: Namespace that provides ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
        super().__init__()

        norm_kwargs = {"elementwise_affine": False, "dtype": dtype, "device": device}
        self.norm1 = operations.LayerNorm(dim, **norm_kwargs)
        self.norm2 = operations.LayerNorm(dim, **norm_kwargs)

        # Emits the six modulation tensors (shift/scale/gate for attention
        # and for the MLP) that forward() unpacks.
        modulation = operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device)
        self.modCX = nn.Sequential(nn.SiLU(), modulation)

        self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

    #@torch.compile()
    def forward(self, cx, global_cond, **kwargs):
        """Run the block on ``cx`` and return the updated stream.

        ``**kwargs`` is accepted and ignored.
        """
        residual = cx
        modulation = self.modCX(global_cond).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation

        attn_out = self.attn(modulate(self.norm1(cx), shift_msa, scale_msa))
        normed = self.norm2(residual + gate_msa.unsqueeze(1) * attn_out)
        mlp_out = self.mlp(modulate(normed, shift_mlp, scale_mlp))

        # The final residual adds the block's original input, not `normed`.
        return residual + gate_mlp.unsqueeze(1) * mlp_out
|
||||
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
    """Embeds timesteps with sinusoidal features followed by a two-layer MLP.

    Args:
        hidden_size: Width of the returned embedding.
        frequency_embedding_size: Width of the sinusoidal feature vector.
        dtype: Forwarded to the ``operations.Linear`` constructors.
        device: Forwarded to the ``operations.Linear`` constructors.
        operations: Namespace that provides the ``Linear`` layer class.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        factory_kwargs = {"dtype": dtype, "device": device}
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, **factory_kwargs),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, **factory_kwargs),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return ``[cos | sin]`` features of ``t`` with ``dim`` columns.

        Args:
            t: 1-D tensor of timesteps (indexed as ``t[:, None]``).
            dim: Number of output columns; an odd ``dim`` gets one trailing
                zero column.
            max_period: Base of the geometric frequency progression.
        """
        half = dim // 2
        exponent = -math.log(max_period) * torch.arange(start=0, end=half) / half
        # Frequencies are built on the default device, then moved to t's.
        freqs = 1000 * torch.exp(exponent).to(t.device)

        angles = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Pad with a zero column so the width is exactly dim.
            padding = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, padding], dim=-1)
        return embedding

    #@torch.compile()
    def forward(self, t, dtype):
        """Embed ``t``; the sinusoidal features are cast to ``dtype`` before the MLP."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        return self.mlp(t_freq)
|
||||
|
||||
|
||||
class MMDiT(nn.Module):
    """Diffusion transformer over patchified image tokens and a conditioning sequence.

    The first ``n_double_layers`` blocks (``MMDiTBlock``) process the
    conditioning and image sequences as a pair; the remaining
    ``n_layers - n_double_layers`` blocks (``DiTBlock``) process their
    concatenation. A timestep embedding modulates every block and the final
    projection.

    Args:
        in_channels: Channel count of the input image tensor.
        out_channels: Channel count of the output image tensor.
        patch_size: Side length of the square patches.
        dim: Token width used throughout the transformer.
        n_layers: Total number of blocks (double + single).
        n_double_layers: Number of ``MMDiTBlock`` layers.
        n_heads: Attention heads per block.
        global_conddim: Width of the timestep conditioning vector.
        cond_seq_dim: Feature width of the incoming conditioning sequence.
        max_seq: Number of learned positional encodings; treated as a square
            grid of side ``round(max_seq ** 0.5)``.
        device, dtype: Forwarded to parameter and layer constructors.
        operations: Namespace providing the ``Linear`` class to instantiate.
    """

    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        patch_size=2,
        dim=3072,
        n_layers=36,
        n_double_layers=4,
        n_heads=12,
        global_conddim=3072,
        cond_seq_dim=2048,
        max_seq=32 * 32,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype

        self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)

        self.cond_seq_linear = operations.Linear(
            cond_seq_dim, dim, bias=False, dtype=dtype, device=device
        )  # linear for something like text sequence.
        self.init_x_linear = operations.Linear(
            patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
        )  # init linear for patchified image.

        # Allocated uninitialized here; values are not set anywhere in this class.
        self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
        self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))

        self.double_layers = nn.ModuleList([])
        self.single_layers = nn.ModuleList([])

        for idx in range(n_double_layers):
            # is_last is only true when the final block of the whole stack is a double block.
            self.double_layers.append(
                MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
            )

        for idx in range(n_double_layers, n_layers):
            self.single_layers.append(
                DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
            )

        self.final_linear = operations.Linear(
            dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
        )

        self.modF = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
        )

        self.out_channels = out_channels
        self.patch_size = patch_size
        self.n_double_layers = n_double_layers
        self.n_layers = n_layers

        self.h_max = round(max_seq**0.5)
        self.w_max = round(max_seq**0.5)

    def _patch_count(self, size):
        """Return the number of patches covering ``size`` pixels after padding up to a patch multiple."""
        # Ceil division; equals (size + 1) // 2 for the default patch_size of 2.
        return -(-size // self.patch_size)

    @torch.no_grad()
    def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
        """Resample the positional-encoding grid from ``init_dim`` to ``target_dim``.

        The first ``init_dim[0] * init_dim[1]`` encodings are viewed as a 2D
        grid, bilinearly interpolated to ``target_dim``, and written back as
        the parameter data. ``h_max``/``w_max`` are updated to ``target_dim``.
        """
        pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]

        # (H, W, D) -> (D, H, W) so interpolation runs over the spatial axes.
        pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)

        pe_as_2d = F.interpolate(
            pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
        )
        pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
        self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
        self.h_max, self.w_max = target_dim
        print("PE extended to", target_dim)

    def pe_selection_index_based_on_dim(self, h, w):
        """Return flat indices of a centered ``(h // p, w // p)`` window of the positional grid."""
        h_p, w_p = h // self.patch_size, w // self.patch_size
        original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
        original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
        starth = self.h_max // 2 - h_p // 2
        endh = starth + h_p
        startw = self.w_max // 2 - w_p // 2
        endw = startw + w_p
        original_pe_indexes = original_pe_indexes[
            starth:endh, startw:endw
        ]
        return original_pe_indexes.flatten()

    def unpatchify(self, x, h, w):
        """Rearrange ``(N, h * w, p * p * c)`` tokens into an ``(N, c, h * p, w * p)`` image.

        ``h`` and ``w`` are patch counts, not pixel sizes.
        """
        c = self.out_channels
        p = self.patch_size

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def patchify(self, x):
        """Split a ``(B, C, H, W)`` image into ``(B, patches, C * p * p)`` tokens.

        The bottom and right edges are reflect-padded so both spatial sizes
        become multiples of the patch size.
        """
        B, C, H, W = x.size()
        p = self.patch_size
        pad_h = (p - H % p) % p
        pad_w = (p - W % p) % p

        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
        # Use the padded extent for the patch count; (H + 1) // p is only right for p <= 2.
        x = x.view(
            B,
            C,
            (H + pad_h) // p,
            p,
            (W + pad_w) // p,
            p,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

    def apply_pos_embeds(self, x, h, w):
        """Add a centered crop of the positional-encoding grid to the tokens ``x``.

        ``h`` and ``w`` are pixel sizes of the unpadded input. If the patch
        grid is larger than the stored grid, the stored grid is bilinearly
        upsampled to a square of side ``max(h_patches, w_patches)`` first.
        """
        h = self._patch_count(h)
        w = self._patch_count(w)
        max_dim = max(h, w)

        # NOTE(review): the reshape treats the stored grid as square (h_max x h_max);
        # confirm this still holds after extend_pe with a non-square target_dim.
        cur_dim = self.h_max
        pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)

        if max_dim > cur_dim:
            pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
            cur_dim = max_dim

        from_h = (cur_dim - h) // 2
        from_w = (cur_dim - w) // 2
        pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w]
        return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

    def forward(self, x, timestep, context, **kwargs):
        """Predict an output image tensor of the same spatial size as ``x``.

        Args:
            x: Input of shape ``(B, C, H, W)``.
            timestep: Timesteps passed to the timestep embedder.
            context: Conditioning sequence projected by ``cond_seq_linear``.
            **kwargs: Forwarded to every transformer block.

        Returns:
            Tensor of shape ``(B, out_channels, H, W)``; padding added for
            patchification is cropped away.
        """
        _, _, h, w = x.shape

        # patchify x, add PE
        x = self.init_x_linear(self.patchify(x))  # B, T_x, D
        x = self.apply_pos_embeds(x, h, w)

        # process conditions for MMDiT Blocks
        c = self.cond_seq_linear(context)  # B, T_c, D
        registers = self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1)
        c = torch.cat([registers, c], dim=1)

        global_cond = self.t_embedder(timestep, x.dtype)  # B, D

        for layer in self.double_layers:
            c, x = layer(c, x, global_cond, **kwargs)

        if len(self.single_layers) > 0:
            c_len = c.size(1)
            cx = torch.cat([c, x], dim=1)
            for layer in self.single_layers:
                cx = layer(cx, global_cond, **kwargs)

            # Drop the conditioning tokens; only image tokens are decoded.
            x = cx[:, c_len:]

        fshift, fscale = self.modF(global_cond).chunk(2, dim=1)

        x = modulate(x, fshift, fscale)
        x = self.final_linear(x)
        x = self.unpatchify(x, self._patch_count(h), self._patch_count(w))[:, :, :h, :w]
        return x
|
||||
@ -272,4 +272,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, model_base.AuraFlow): #Diffusers lora AuraFlow
|
||||
diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
|
||||
key_map[key_lora] = to
|
||||
|
||||
return key_map
|
||||
|
||||
@ -17,7 +17,7 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
|
||||
from .ldm.aura.mmdit import MMDiT as AuraMMDiT
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
@ -622,6 +622,17 @@ class SD3(BaseModel):
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (area * 0.3) * (1024 * 1024)
|
||||
|
||||
class AuraFlow(BaseModel):
    """Model wrapper that builds its diffusion network from ``AuraMMDiT``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=AuraMMDiT)

    def extra_conds(self, **kwargs):
        """Extend the parent's conditioning dict with ``c_crossattn`` when ``cross_attn`` is supplied."""
        extra = super().extra_conds(**kwargs)
        text_embedding = kwargs.get("cross_attn", None)
        if text_embedding is None:
            return extra
        extra['c_crossattn'] = conds.CONDRegular(text_embedding)
        return extra
|
||||
|
||||
|
||||
class StableAudio1(BaseModel):
|
||||
def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
|
||||
|
||||
@ -104,6 +104,19 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
unet_config["audio_model"] = "dit1.0"
|
||||
return unet_config
|
||||
|
||||
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
|
||||
unet_config = {}
|
||||
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
|
||||
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
|
||||
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
|
||||
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
|
||||
unet_config["n_double_layers"] = double_layers
|
||||
unet_config["n_layers"] = double_layers + single_layers
|
||||
return unet_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
unet_config = {
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
@ -238,6 +251,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||
if unet_config is None:
|
||||
return None
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return supported_models_base.BASE(unet_config)
|
||||
@ -247,6 +262,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
def unet_prefix_from_state_dict(state_dict):
    """Pick the key prefix under which the diffusion model is stored.

    The choice is made by probing for marker keys, in priority order; the
    default is ``"model.diffusion_model."``.
    """
    if "model.model.postprocess_conv.weight" in state_dict:  # audio models
        return "model.model."
    if "model.double_layers.0.attn.w1q.weight" in state_dict:  # aura flow
        return "model."
    return "model.diffusion_model."
|
||||
@ -436,12 +453,19 @@ def model_config_from_diffusers_unet(state_dict):
|
||||
return None
|
||||
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
out_sd = None
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
if num_blocks > 0:
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
out_sd = {}
|
||||
|
||||
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
sd_map = utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
|
||||
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
||||
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
sd_map = utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
||||
else:
|
||||
return None
|
||||
|
||||
for k in sd_map:
|
||||
weight = state_dict.get(k, None)
|
||||
if weight is not None:
|
||||
|
||||
@ -60,6 +60,12 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl
|
||||
model_options["disable_cfg1_optimization"] = True
|
||||
return model_options
|
||||
|
||||
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
    """Register ``pre_cfg_function`` under ``"sampler_pre_cfg_function"`` in ``model_options``.

    A new list is stored, so a previously stored list object is left
    untouched. ``model_options`` itself is updated in place and returned.
    When ``disable_cfg1_optimization`` is truthy the matching flag is set to
    ``True``.
    """
    registered = model_options.get("sampler_pre_cfg_function", [])
    model_options["sampler_pre_cfg_function"] = registered + [pre_cfg_function]
    if not disable_cfg1_optimization:
        return model_options
    model_options["disable_cfg1_optimization"] = True
    return model_options
|
||||
|
||||
class ModelPatcher(ModelManageable):
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||
self.size = size
|
||||
@ -142,6 +148,9 @@ class ModelPatcher(ModelManageable):
|
||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
    """Register a post-CFG callback on this patcher's ``model_options``."""
    updated_options = set_model_options_post_cfg_function(
        self.model_options, post_cfg_function, disable_cfg1_optimization
    )
    self.model_options = updated_options
|
||||
|
||||
def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
    """Register a pre-CFG callback on this patcher's ``model_options``."""
    updated_options = set_model_options_pre_cfg_function(
        self.model_options, pre_cfg_function, disable_cfg1_optimization
    )
    self.model_options = updated_options
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
|
||||
|
||||
@ -192,11 +192,12 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
self.set_parameters(shift=sampling_settings.get("shift", 1.0))
|
||||
self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
|
||||
|
||||
def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
    """Store the sampling parameters and rebuild the ``sigmas`` buffer.

    Args:
        shift: Stored on ``self.shift``.
        timesteps: Number of entries in the sigma table.
        multiplier: Scale between sigma and timestep values; stored on
            ``self.multiplier``.
    """
    self.shift = shift
    # Must be set before self.sigma() is called below, which divides by it.
    self.multiplier = multiplier
    # The grid i / timesteps (i = 1..timesteps) is scaled by the multiplier here
    # and divided by it again inside self.sigma().
    ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
    self.register_buffer('sigmas', ts)
|
||||
|
||||
@property
|
||||
@ -208,10 +209,10 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
    """Convert a sigma value to a timestep by scaling with ``self.multiplier``."""
    return sigma * self.multiplier
|
||||
|
||||
def sigma(self, timestep):
    """Convert a timestep to sigma: divide by ``self.multiplier``, then apply ``time_snr_shift`` with ``self.shift``."""
    return time_snr_shift(self.shift, timestep / self.multiplier)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
|
||||
@ -46,8 +46,9 @@ class CLIPTextEncode:
|
||||
|
||||
def encode(self, clip, text):
    """Tokenize and encode ``text`` with ``clip`` into a conditioning list.

    Returns:
        A 1-tuple holding ``[[cond, extras]]``, where ``extras`` is the
        encoder's output dict with the ``"cond"`` entry removed, so every
        additional output the encoder returns is passed through.
    """
    tokens = clip.tokenize(text)
    output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
    cond = output.pop("cond")
    return ([[cond, output]], )
|
||||
|
||||
class ConditioningCombine:
|
||||
@classmethod
|
||||
@ -223,8 +224,9 @@ class ConditioningZeroOut:
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
if "pooled_output" in d:
|
||||
d["pooled_output"] = torch.zeros_like(d["pooled_output"])
|
||||
pooled_output = d.get("pooled_output", None)
|
||||
if pooled_output is not None:
|
||||
d["pooled_output"] = torch.zeros_like(pooled_output)
|
||||
n = [torch.zeros_like(t[0]), d]
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
@ -1,22 +1,27 @@
|
||||
from comfy import sd1_clip
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
import comfy.t5
|
||||
import os
|
||||
from comfy import sd1_clip
|
||||
from comfy.component_model import files
|
||||
|
||||
|
||||
class T5BaseModel(sd1_clip.SDClipModel):
    """``SDClipModel`` configured for ``comfy.t5.T5`` with attention masks and masked-output zeroing enabled."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
        # presumably falls back to the packaged "t5_config_base.json" when no config is given -- confirm in files.get_path_as_dict
        textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_base.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
|
||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
    """``SDTokenizer`` backed by ``T5TokenizerFast``, loaded from the ``comfy.t5_tokenizer`` package."""

    def __init__(self, embedding_directory=None):
        # NOTE(review): embedding_directory is accepted but not forwarded to the parent constructor -- confirm this is intended.
        tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
|
||||
|
||||
|
||||
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
    """``SD1Tokenizer`` wrapper that registers ``T5BaseTokenizer`` under the name ``"t5base"``."""

    def __init__(self, embedding_directory=None):
        super().__init__(
            tokenizer=T5BaseTokenizer,
            clip_name="t5base",
            embedding_directory=embedding_directory,
        )
|
||||
|
||||
|
||||
class SAT5Model(sd1_clip.SD1ClipModel):
    """``SD1ClipModel`` wrapper that hosts a ``T5BaseModel`` under the name ``"t5base"``."""

    def __init__(self, device="cpu", dtype=None, **kwargs):
        # Initialise the parent exactly once, using the ``name`` keyword.
        super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
|
||||
|
||||
@ -278,6 +278,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
|
||||
conds = [cond, uncond_]
|
||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||
|
||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
|
||||
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
|
||||
out = fn(args)
|
||||
|
||||
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
|
||||
|
||||
|
||||
|
||||
72
comfy/sd.py
72
comfy/sd.py
@ -28,37 +28,7 @@ from .t2i_adapter import adapter
|
||||
from .taesd import taesd
|
||||
from . import sd3_clip
|
||||
from . import sa_t5
|
||||
|
||||
|
||||
def load_model_weights(model, sd):
    """Load ``sd`` into ``model`` non-strictly and drop the consumed entries from ``sd``.

    Keys that the model reported as unexpected are left in ``sd``; every other
    key is removed. Missing keys are logged as a warning.

    Returns:
        ``model``.
    """
    missing, unexpected = model.load_state_dict(sd, strict=False)
    missing_keys = set(missing)
    leftover_keys = set(unexpected)

    # Snapshot the keys first: sd is mutated while we walk them.
    for name in list(sd.keys()):
        if name in leftover_keys:
            continue
        sd.pop(name)

    if len(missing_keys) > 0:
        logging.warning("missing {}".format(missing_keys))
    return model
|
||||
|
||||
|
||||
def load_clip_weights(model, sd):
    """Normalize CLIP text-encoder keys in ``sd`` and load the result into ``model``.

    Keys under ``cond_stage_model.transformer.`` that lack the ``text_model.``
    segment are renamed to include it, float32 ``position_ids`` are rounded,
    and the dict is passed through ``utils.clip_text_transformers_convert``
    before loading.
    """
    transformer_prefix = "cond_stage_model.transformer."
    text_model_prefix = "cond_stage_model.transformer.text_model."
    position_ids_key = 'cond_stage_model.transformer.text_model.embeddings.position_ids'

    # Snapshot the keys first: sd is mutated while we walk them.
    for old_key in list(sd.keys()):
        if not old_key.startswith(transformer_prefix):
            continue
        if old_key.startswith(text_model_prefix):
            continue
        new_key = old_key.replace(transformer_prefix, text_model_prefix)
        sd[new_key] = sd.pop(old_key)

    if position_ids_key in sd:
        position_ids = sd[position_ids_key]
        if position_ids.dtype == torch.float32:
            sd[position_ids_key] = position_ids.round()

    sd = utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.")
    return load_model_weights(model, sd)
|
||||
from .text_encoders import aura_t5
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
|
||||
@ -136,7 +106,7 @@ class CLIP:
|
||||
def tokenize(self, text, return_word_ids=False):
|
||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
def encode_from_tokens(self, tokens, return_pooled=False):
|
||||
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
if self.layer_idx is not None:
|
||||
@ -146,7 +116,15 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
if return_dict:
|
||||
out = {"cond": cond, "pooled_output": pooled}
|
||||
if len(o) > 2:
|
||||
for k in o[2]:
|
||||
out[k] = o[2][k]
|
||||
return out
|
||||
|
||||
if return_pooled:
|
||||
return cond, pooled
|
||||
return cond
|
||||
@ -447,9 +425,14 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
|
||||
dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
|
||||
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
dtype_t5 = weight.dtype
|
||||
if weight.shape[-1] == 4096:
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
elif weight.shape[-1] == 2048:
|
||||
clip_target.clip = aura_t5.AuraT5Model
|
||||
clip_target.tokenizer = aura_t5.AuraT5Tokenizer
|
||||
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
|
||||
clip_target.clip = sa_t5.SAT5Model
|
||||
clip_target.tokenizer = sa_t5.SAT5Tokenizer
|
||||
@ -529,13 +512,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
if output_clipvision:
|
||||
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
||||
@ -586,30 +569,25 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
def load_unet_state_dict(sd): # load unet in diffusers or regular format
|
||||
|
||||
#Allow loading unets from checkpoint files
|
||||
checkpoint = False
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
temp_sd = utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
||||
if len(temp_sd) > 0:
|
||||
sd = temp_sd
|
||||
checkpoint = True
|
||||
|
||||
parameters = utils.calculate_parameters(sd)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: # ldm or stable cascade
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
|
||||
if model_config is not None:
|
||||
new_sd = sd
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
|
||||
else:
|
||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||
if new_sd is None:
|
||||
return None
|
||||
if new_sd is not None: #diffusers mmdit
|
||||
model_config = model_detection.model_config_from_unet(new_sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
else: # diffusers
|
||||
else: # diffusers unet
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd)
|
||||
if model_config is None:
|
||||
return None
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import importlib.resources
|
||||
import logging
|
||||
import numbers
|
||||
import os
|
||||
import traceback
|
||||
import zipfile
|
||||
from importlib.abc import Traversable
|
||||
from typing import Tuple, Sequence, TypeVar
|
||||
|
||||
import torch
|
||||
@ -14,6 +16,7 @@ from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMi
|
||||
from . import clip_model
|
||||
from . import model_management
|
||||
from . import ops
|
||||
from .component_model import files
|
||||
from .component_model.files import get_path_as_dict, get_package_as_path
|
||||
|
||||
|
||||
@ -29,7 +32,58 @@ def gen_empty_tokens(special_tokens, length):
|
||||
output += [pad_token] * (length - len(output))
|
||||
return output
|
||||
|
||||
class SDClipModel(torch.nn.Module):
|
||||
class ClipTokenWeightEncoder:
    """Mixin that encodes weighted token sections using the host's ``encode`` and ``special_tokens``."""

    def encode_token_weights(self, token_weight_pairs):
        """Encode sections of ``(token, weight, ...)`` entries and apply the weights.

        When any weight differs from 1.0 (or there are no sections), an
        empty-prompt section is encoded alongside; weighted token embeddings
        are then scaled relative to that empty embedding.

        Returns:
            ``(cond, first_pooled)`` or ``(cond, first_pooled, extra)`` when
            ``self.encode`` returns a third element. ``cond`` is the sections
            concatenated along dim -2; an ``"attention_mask"`` entry in
            ``extra`` is cut to the real sections, flattened, and given a
            leading dim of 1.
        """
        batches = []
        longest = 0
        weighted = False
        for section in token_weight_pairs:
            ids = [entry[0] for entry in section]
            longest = max(len(ids), longest)
            if not weighted:
                weighted = not all(entry[1] == 1.0 for entry in section)
            batches.append(ids)

        sections = len(batches)
        if weighted or sections == 0:
            # The empty prompt is always the last row of the encoded batch.
            batches.append(gen_empty_tokens(self.special_tokens, longest))

        encoded = self.encode(batches)
        out, pooled = encoded[:2]

        first_pooled = pooled
        if pooled is not None:
            first_pooled = pooled[0:1].to(model_management.intermediate_device())

        pieces = []
        for k in range(sections):
            z = out[k:k + 1]
            if weighted:
                z_empty = out[-1]
                for i in range(len(z)):
                    row = z[i]
                    for j in range(len(row)):
                        weight = token_weight_pairs[k][j][1]
                        if weight == 1.0:
                            continue
                        # Writes through the view into ``out``.
                        row[j] = (row[j] - z_empty[j]) * weight + z_empty[j]
            pieces.append(z)

        if pieces:
            cond = torch.cat(pieces, dim=-2).to(model_management.intermediate_device())
        else:
            cond = out[-1:].to(model_management.intermediate_device())
        result = (cond, first_pooled)

        if len(encoded) > 2:
            extra = {}
            for key in encoded[2]:
                value = encoded[2][key]
                if key == "attention_mask":
                    value = value[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
                extra[key] = value
            result = result + (extra,)
        return result
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
@ -40,7 +94,7 @@ class SDClipModel(torch.nn.Module):
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
|
||||
special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||
return_projected_pooled=True): # clip-vit-base-patch32
|
||||
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
if special_tokens is None:
|
||||
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
|
||||
@ -63,6 +117,7 @@ class SDClipModel(torch.nn.Module):
|
||||
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
self.return_attention_masks = return_attention_masks
|
||||
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
@ -136,7 +191,7 @@ class SDClipModel(torch.nn.Module):
|
||||
tokens = torch.tensor(tokens, dtype=torch.long).to(device)
|
||||
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks:
|
||||
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
end_token = self.special_tokens.get("end", -1)
|
||||
for x in range(attention_mask.shape[0]):
|
||||
@ -145,7 +200,11 @@ class SDClipModel(torch.nn.Module):
|
||||
if tokens[x, y] == end_token:
|
||||
break
|
||||
|
||||
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
||||
attention_mask_model = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask_model = attention_mask
|
||||
|
||||
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
|
||||
if self.layer == "last":
|
||||
@ -153,7 +212,7 @@ class SDClipModel(torch.nn.Module):
|
||||
else:
|
||||
z = outputs[1].float()
|
||||
|
||||
if self.zero_out_masked and attention_mask is not None:
|
||||
if self.zero_out_masked:
|
||||
z *= attention_mask.unsqueeze(-1).float()
|
||||
|
||||
pooled_output = None
|
||||
@ -163,6 +222,13 @@ class SDClipModel(torch.nn.Module):
|
||||
elif outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
|
||||
extra = {}
|
||||
if self.return_attention_masks:
|
||||
extra["attention_mask"] = attention_mask
|
||||
|
||||
if len(extra) > 0:
|
||||
return z, pooled_output, extra
|
||||
|
||||
return z, pooled_output
|
||||
|
||||
def encode(self, tokens):
|
||||
@ -374,10 +440,13 @@ SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
|
||||
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
|
||||
tokenizer_path = files.get_package_as_path("comfy.sd1_tokenizer")
|
||||
if isinstance(tokenizer_path, Traversable):
|
||||
contextlib_path = importlib.resources.as_file(tokenizer_path)
|
||||
tokenizer_path = contextlib_path.__enter__()
|
||||
if not tokenizer_path.endswith(".model") and not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
|
||||
# package based
|
||||
tokenizer_path = get_package_as_path('comfy.sd1_tokenizer')
|
||||
self.tokenizer_class = tokenizer_class
|
||||
@ -395,6 +464,14 @@ class SDTokenizer:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
|
||||
if pad_token is not None:
|
||||
self.pad_token = pad_token
|
||||
elif pad_with_end:
|
||||
self.pad_token = self.end_token
|
||||
else:
|
||||
self.pad_token = 0
|
||||
|
||||
self.pad_with_end = pad_with_end
|
||||
self.pad_to_max_length = pad_to_max_length
|
||||
self.additional_tokens: Tuple[str, ...] = ()
|
||||
@ -439,10 +516,6 @@ class SDTokenizer:
|
||||
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
|
||||
Returned list has the dimensions NxM where M is the input size of CLIP
|
||||
'''
|
||||
if self.pad_with_end:
|
||||
pad_token = self.end_token
|
||||
else:
|
||||
pad_token = 0
|
||||
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
@ -502,7 +575,7 @@ class SDTokenizer:
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
||||
# start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
@ -515,9 +588,9 @@ class SDTokenizer:
|
||||
# fill last batch
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
|
||||
@ -560,10 +633,16 @@ class SD1Tokenizer:
|
||||
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
if name is not None:
|
||||
self.clip_name = name
|
||||
self.clip = "{}".format(self.clip_name)
|
||||
else:
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))
|
||||
|
||||
self.dtypes = set()
|
||||
@ -578,8 +657,8 @@ class SD1ClipModel(torch.nn.Module):
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs = token_weight_pairs[self.clip_name]
|
||||
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
||||
return out, pooled
|
||||
out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
||||
return out
|
||||
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
||||
@ -1,32 +1,38 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.t5
|
||||
from comfy import sd1_clip
|
||||
from comfy import sdxl_clip
|
||||
from transformers import T5TokenizerFast
|
||||
import comfy.t5
|
||||
import torch
|
||||
import os
|
||||
import comfy.model_management
|
||||
import logging
|
||||
from comfy.component_model import files
|
||||
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
|
||||
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||
tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
|
||||
|
||||
class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
class SDT5XXLModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
|
||||
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self, embedding_directory=None):
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
@ -43,6 +49,7 @@ class SD3Tokenizer:
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.clip_g.untokenize(token_weight_pair)
|
||||
|
||||
|
||||
class SD3ClipModel(torch.nn.Module):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
@ -143,8 +150,10 @@ class SD3ClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
||||
class SD3ClipModel_(SD3ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||
|
||||
return SD3ClipModel_
|
||||
|
||||
@ -7,6 +7,7 @@ from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
from . import sd3_clip
|
||||
from . import sa_t5
|
||||
from .text_encoders import aura_t5
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -556,7 +557,29 @@ class StableAudio(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
|
||||
|
||||
class AuraFlow(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"cond_seq_dim": 2048,
|
||||
}
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.73,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.AuraFlow(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(aura_t5.AuraT5Tokenizer, aura_t5.AuraT5Model)
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
35
comfy/t5.py
35
comfy/t5.py
@ -13,29 +13,36 @@ class T5LayerNorm(torch.nn.Module):
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight.to(device=x.device, dtype=x.dtype) * x
|
||||
|
||||
activations = {
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
"relu": torch.nn.functional.relu,
|
||||
}
|
||||
|
||||
class T5DenseActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device, operations):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
self.act = activations[ff_activation]
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.relu(self.wi(x))
|
||||
x = self.act(self.wi(x))
|
||||
# x = self.dropout(x)
|
||||
x = self.wo(x)
|
||||
return x
|
||||
|
||||
class T5DenseGatedActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device, operations):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
self.act = activations[ff_activation]
|
||||
|
||||
def forward(self, x):
|
||||
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
||||
hidden_gelu = self.act(self.wi_0(x))
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_gelu * hidden_linear
|
||||
# x = self.dropout(x)
|
||||
@ -43,12 +50,12 @@ class T5DenseGatedActDense(torch.nn.Module):
|
||||
return x
|
||||
|
||||
class T5LayerFF(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
|
||||
super().__init__()
|
||||
if ff_activation == "gelu_pytorch_tanh":
|
||||
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations)
|
||||
elif ff_activation == "relu":
|
||||
self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations)
|
||||
if gated_act:
|
||||
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
|
||||
else:
|
||||
self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
|
||||
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
@ -171,11 +178,11 @@ class T5LayerSelfAttention(torch.nn.Module):
|
||||
return x, past_bias
|
||||
|
||||
class T5Block(torch.nn.Module):
|
||||
def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations):
|
||||
def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList()
|
||||
self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
|
||||
self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations))
|
||||
self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations))
|
||||
|
||||
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||
x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
|
||||
@ -183,11 +190,11 @@ class T5Block(torch.nn.Module):
|
||||
return x, past_bias
|
||||
|
||||
class T5Stack(torch.nn.Module):
|
||||
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations):
|
||||
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
self.block = torch.nn.ModuleList(
|
||||
[T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
|
||||
[T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=((not relative_attention) or (i == 0)), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
|
||||
)
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
@ -216,7 +223,7 @@ class T5(torch.nn.Module):
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
model_dim = config_dict["d_model"]
|
||||
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations)
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] == "t5", dtype, device, operations)
|
||||
self.dtype = dtype
|
||||
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
"dense_act_fn": "relu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": false,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 12,
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 24,
|
||||
|
||||
0
comfy/text_encoders/__init__.py
Normal file
0
comfy/text_encoders/__init__.py
Normal file
28
comfy/text_encoders/aura_t5.py
Normal file
28
comfy/text_encoders/aura_t5.py
Normal file
@ -0,0 +1,28 @@
|
||||
from importlib import resources
|
||||
|
||||
from comfy import sd1_clip
|
||||
from .llama_tokenizer import LLAMATokenizer
|
||||
from .. import t5
|
||||
from ..component_model.files import get_path_as_dict
|
||||
|
||||
|
||||
class PT5XlModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
textmodel_json_config = get_path_as_dict(textmodel_json_config, "t5_pile_config_xl.json", package="comfy.text_encoders")
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
|
||||
class PT5XlTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
tokenizer_path = resources.files("comfy.text_encoders.t5_pile_tokenizer") / "tokenizer.model"
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LLAMATokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
|
||||
|
||||
|
||||
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
||||
|
||||
|
||||
class AuraT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
||||
22
comfy/text_encoders/llama_tokenizer.py
Normal file
22
comfy/text_encoders/llama_tokenizer.py
Normal file
@ -0,0 +1,22 @@
|
||||
import os
|
||||
|
||||
class LLAMATokenizer:
|
||||
@staticmethod
|
||||
def from_pretrained(path):
|
||||
return LLAMATokenizer(path)
|
||||
|
||||
def __init__(self, tokenizer_path):
|
||||
import sentencepiece
|
||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
|
||||
self.end = self.tokenizer.eos_id()
|
||||
|
||||
def get_vocab(self):
|
||||
out = {}
|
||||
for i in range(self.tokenizer.get_piece_size()):
|
||||
out[self.tokenizer.id_to_piece(i)] = i
|
||||
return out
|
||||
|
||||
def __call__(self, string):
|
||||
out = self.tokenizer.encode(string)
|
||||
out += [self.end]
|
||||
return {"input_ids": out}
|
||||
22
comfy/text_encoders/t5_pile_config_xl.json
Normal file
22
comfy/text_encoders/t5_pile_config_xl.json
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"d_ff": 5120,
|
||||
"d_kv": 64,
|
||||
"d_model": 2048,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 2,
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "umt5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 32,
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 1,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"vocab_size": 32128
|
||||
}
|
||||
0
comfy/text_encoders/t5_pile_tokenizer/__init__.py
Normal file
0
comfy/text_encoders/t5_pile_tokenizer/__init__.py
Normal file
BIN
comfy/text_encoders/t5_pile_tokenizer/tokenizer.model
Normal file
BIN
comfy/text_encoders/t5_pile_tokenizer/tokenizer.model
Normal file
Binary file not shown.
@ -20,6 +20,8 @@ from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import checkpoint_pickle, interruption
|
||||
from .component_model import files
|
||||
from .component_model.deprecation import _deprecate_method
|
||||
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
|
||||
from .component_model.queue_types import BinaryEventTypes
|
||||
from .execution_context import current_execution_context
|
||||
@ -374,6 +376,76 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
|
||||
return key_map
|
||||
|
||||
|
||||
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_double_layers = mmdit_config.get("n_double_layers", 0)
|
||||
n_layers = mmdit_config.get("n_layers", 0)
|
||||
|
||||
key_map = {}
|
||||
for i in range(n_layers):
|
||||
if i < n_double_layers:
|
||||
index = i
|
||||
prefix_from = "joint_transformer_blocks"
|
||||
prefix_to = "{}double_layers".format(output_prefix)
|
||||
block_map = {
|
||||
"attn.to_q.weight": "attn.w2q.weight",
|
||||
"attn.to_k.weight": "attn.w2k.weight",
|
||||
"attn.to_v.weight": "attn.w2v.weight",
|
||||
"attn.to_out.0.weight": "attn.w2o.weight",
|
||||
"attn.add_q_proj.weight": "attn.w1q.weight",
|
||||
"attn.add_k_proj.weight": "attn.w1k.weight",
|
||||
"attn.add_v_proj.weight": "attn.w1v.weight",
|
||||
"attn.to_add_out.weight": "attn.w1o.weight",
|
||||
"ff.linear_1.weight": "mlpX.c_fc1.weight",
|
||||
"ff.linear_2.weight": "mlpX.c_fc2.weight",
|
||||
"ff.out_projection.weight": "mlpX.c_proj.weight",
|
||||
"ff_context.linear_1.weight": "mlpC.c_fc1.weight",
|
||||
"ff_context.linear_2.weight": "mlpC.c_fc2.weight",
|
||||
"ff_context.out_projection.weight": "mlpC.c_proj.weight",
|
||||
"norm1.linear.weight": "modX.1.weight",
|
||||
"norm1_context.linear.weight": "modC.1.weight",
|
||||
}
|
||||
else:
|
||||
index = i - n_double_layers
|
||||
prefix_from = "single_transformer_blocks"
|
||||
prefix_to = "{}single_layers".format(output_prefix)
|
||||
|
||||
block_map = {
|
||||
"attn.to_q.weight": "attn.w1q.weight",
|
||||
"attn.to_k.weight": "attn.w1k.weight",
|
||||
"attn.to_v.weight": "attn.w1v.weight",
|
||||
"attn.to_out.0.weight": "attn.w1o.weight",
|
||||
"norm1.linear.weight": "modCX.1.weight",
|
||||
"ff.linear_1.weight": "mlp.c_fc1.weight",
|
||||
"ff.linear_2.weight": "mlp.c_fc2.weight",
|
||||
"ff.out_projection.weight": "mlp.c_proj.weight"
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])
|
||||
|
||||
MAP_BASIC = {
|
||||
("positional_encoding", "pos_embed.pos_embed"),
|
||||
("register_tokens", "register_tokens"),
|
||||
("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
|
||||
("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
|
||||
("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
|
||||
("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
|
||||
("cond_seq_linear.weight", "context_embedder.weight"),
|
||||
("init_x_linear.weight", "pos_embed.proj.weight"),
|
||||
("init_x_linear.bias", "pos_embed.proj.bias"),
|
||||
("final_linear.weight", "proj_out.weight"),
|
||||
("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
}
|
||||
|
||||
for k in MAP_BASIC:
|
||||
if len(k) > 2:
|
||||
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
||||
else:
|
||||
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
@ -675,8 +747,9 @@ class ProgressBar:
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
|
||||
@_deprecate_method(version="1.0.0", message="The root project directory isn't valid when the application is installed as a package. Use os.getcwd() instead.")
|
||||
def get_project_root() -> str:
|
||||
return os.path.join(os.path.dirname(__file__), "..")
|
||||
return files.get_package_as_path("comfy")
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@ -1599,7 +1599,7 @@ export class ComfyApp {
|
||||
if (json) {
|
||||
const workflow = JSON.parse(json);
|
||||
const workflowName = getStorageValue("Comfy.PreviousWorkflow");
|
||||
await this.loadGraphData(workflow, true, workflowName);
|
||||
await this.loadGraphData(workflow, true, true, workflowName);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
@ -182,6 +182,11 @@ export class ComfyWorkflowsMenu {
|
||||
* @param {ComfyWorkflow} workflow
|
||||
*/
|
||||
async function sendToWorkflow(img, workflow) {
|
||||
const openWorkflow = app.workflowManager.openWorkflows.find((w) => w.path === workflow.path);
|
||||
if (openWorkflow) {
|
||||
workflow = openWorkflow;
|
||||
}
|
||||
|
||||
await workflow.load();
|
||||
let options = [];
|
||||
const nodes = app.graph.computeExecutionOrder(false);
|
||||
@ -214,7 +219,8 @@ export class ComfyWorkflowsMenu {
|
||||
nodeType.prototype["getExtraMenuOptions"] = function (_, options) {
|
||||
const r = getExtraMenuOptions?.apply?.(this, arguments);
|
||||
|
||||
if (app.ui.settings.getSettingValue("Comfy.UseNewMenu", false) === true) {
|
||||
const setting = app.ui.settings.getSettingValue("Comfy.UseNewMenu", false);
|
||||
if (setting && setting != "Disabled") {
|
||||
const t = /** @type { {imageIndex?: number, overIndex?: number, imgs: string[]} } */ /** @type {any} */ (this);
|
||||
let img;
|
||||
if (t.imageIndex != null) {
|
||||
|
||||
@ -41,7 +41,7 @@ body {
|
||||
background-color: var(--bg-color);
|
||||
color: var(--fg-color);
|
||||
grid-template-columns: auto 1fr auto;
|
||||
grid-template-rows: auto auto 1fr auto;
|
||||
grid-template-rows: auto 1fr auto;
|
||||
min-height: -webkit-fill-available;
|
||||
max-height: -webkit-fill-available;
|
||||
min-width: -webkit-fill-available;
|
||||
@ -49,32 +49,37 @@ body {
|
||||
}
|
||||
|
||||
.comfyui-body-top {
|
||||
order: 0;
|
||||
order: -5;
|
||||
grid-column: 1/-1;
|
||||
z-index: 10;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.comfyui-body-left {
|
||||
order: 1;
|
||||
order: -4;
|
||||
z-index: 10;
|
||||
display: flex;
|
||||
}
|
||||
|
||||
#graph-canvas {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
order: 2;
|
||||
grid-column: 1/-1;
|
||||
order: -3;
|
||||
}
|
||||
|
||||
.comfyui-body-right {
|
||||
order: 3;
|
||||
order: -2;
|
||||
z-index: 10;
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.comfyui-body-bottom {
|
||||
order: 4;
|
||||
order: -1;
|
||||
grid-column: 1/-1;
|
||||
z-index: 10;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.comfy-multiline-input {
|
||||
@ -408,8 +413,12 @@ dialog::backdrop {
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
.comfy-dialog.comfyui-dialog {
|
||||
.comfy-dialog.comfyui-dialog.comfy-modal {
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.comfy-dialog.comfy-modal {
|
||||
|
||||
@ -20,7 +20,7 @@ class EmptyLatentAudio:
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "_for_testing/audio"
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def generate(self, seconds):
|
||||
batch_size = 1
|
||||
@ -35,7 +35,7 @@ class VAEEncodeAudio:
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "_for_testing/audio"
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def encode(self, vae, audio):
|
||||
sample_rate = audio["sample_rate"]
|
||||
@ -55,7 +55,7 @@ class VAEDecodeAudio:
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "decode"
|
||||
|
||||
CATEGORY = "_for_testing/audio"
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def decode(self, vae, samples):
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
@ -134,7 +134,7 @@ class SaveAudio:
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing/audio"
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
import torchaudio # pylint: disable=import-error
|
||||
@ -199,7 +199,7 @@ class LoadAudio:
|
||||
]
|
||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
||||
|
||||
CATEGORY = "_for_testing/audio"
|
||||
CATEGORY = "audio"
|
||||
|
||||
RETURN_TYPES = ("AUDIO", )
|
||||
FUNCTION = "load"
|
||||
@ -209,7 +209,6 @@ class LoadAudio:
|
||||
|
||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||
waveform, sample_rate = torchaudio.load(audio_path)
|
||||
multiplier = 1.0
|
||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||
return (audio, )
|
||||
|
||||
|
||||
@ -147,7 +147,7 @@ class ModelSamplingSD3:
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, shift):
|
||||
def patch(self, model, shift, multiplier=1000):
|
||||
m = model.clone()
|
||||
|
||||
sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow
|
||||
@ -157,10 +157,22 @@ class ModelSamplingSD3:
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
model_sampling.set_parameters(shift=shift)
|
||||
model_sampling.set_parameters(shift=shift, multiplier=multiplier)
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
class ModelSamplingAuraFlow(ModelSamplingSD3):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"shift": ("FLOAT", {"default": 1.73, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
}}
|
||||
|
||||
FUNCTION = "patch_aura"
|
||||
|
||||
def patch_aura(self, model, shift):
|
||||
return self.patch(model, shift, multiplier=1.0)
|
||||
|
||||
class ModelSamplingContinuousEDM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -276,5 +288,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingContinuousV": ModelSamplingContinuousV,
|
||||
"ModelSamplingStableCascade": ModelSamplingStableCascade,
|
||||
"ModelSamplingSD3": ModelSamplingSD3,
|
||||
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
|
||||
"RescaleCFG": RescaleCFG,
|
||||
}
|
||||
|
||||
@ -5,6 +5,8 @@ torchsde>=0.2.6
|
||||
einops>=0.6.0
|
||||
open-clip-torch>=2.24.0
|
||||
transformers>=4.29.1
|
||||
tokenizers>=0.13.3
|
||||
sentencepiece
|
||||
peft
|
||||
torchinfo
|
||||
safetensors>=0.4.2
|
||||
|
||||
1
setup.py
1
setup.py
@ -192,6 +192,7 @@ package_data = [
|
||||
't5_tokenizer/*',
|
||||
'**/*.json',
|
||||
'**/*.yaml',
|
||||
'**/*.model'
|
||||
]
|
||||
if not is_editable:
|
||||
package_data.append('comfy/web/**/*')
|
||||
|
||||
@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from aiohttp import ClientSession, ClientConnectorError
|
||||
from aiohttp import ClientSession
|
||||
from testcontainers.rabbitmq import RabbitMqContainer
|
||||
|
||||
from comfy.client.aio_client import AsyncRemoteComfyClient
|
||||
@ -132,13 +132,11 @@ async def test_basic_queue_worker_with_health_check():
|
||||
health_check_port = 9090
|
||||
|
||||
async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
|
||||
# Test health check
|
||||
health_check_url = f"http://localhost:{health_check_port}/health"
|
||||
|
||||
health_check_ok = await check_health(health_check_url)
|
||||
assert health_check_ok, "Health check server did not start properly"
|
||||
|
||||
# Test the actual worker functionality
|
||||
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
|
||||
await distributed_queue.init()
|
||||
@ -153,53 +151,5 @@ async def test_basic_queue_worker_with_health_check():
|
||||
|
||||
await distributed_queue.close()
|
||||
|
||||
# Test that the health check server is stopped after the worker is closed
|
||||
health_check_stopped = not await check_health(health_check_url, max_retries=1)
|
||||
assert health_check_stopped, "Health check server did not stop properly"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_port_conflict():
|
||||
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||
params = rabbitmq.get_connection_params()
|
||||
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
||||
health_check_port = 9090
|
||||
|
||||
# Start a simple server to occupy the health check port
|
||||
from aiohttp import web
|
||||
async def dummy_handler(request):
|
||||
return web.Response(text="Dummy")
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get('/', dummy_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, '0.0.0.0', health_check_port)
|
||||
await site.start()
|
||||
|
||||
try:
|
||||
# Now try to start the DistributedPromptWorker
|
||||
async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
|
||||
# The health check should be disabled, but the worker should still function
|
||||
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
|
||||
await distributed_queue.init()
|
||||
|
||||
queue_item = create_test_prompt()
|
||||
res = await distributed_queue.put_async(queue_item)
|
||||
|
||||
assert res.item_id == queue_item.prompt_id
|
||||
assert len(res.outputs) == 1
|
||||
assert res.status is not None
|
||||
assert res.status.status_str == "success"
|
||||
|
||||
await distributed_queue.close()
|
||||
|
||||
# The original server should still be running
|
||||
async with ClientSession() as session:
|
||||
async with session.get(f"http://localhost:{health_check_port}") as response:
|
||||
assert response.status == 200
|
||||
assert await response.text() == "Dummy"
|
||||
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from comfy.api.components.schema.prompt import Prompt
|
||||
from comfy.cli_args_types import Configuration
|
||||
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
|
||||
from comfy.model_downloader import add_known_models, KNOWN_LORAS
|
||||
from comfy.model_downloader_types import CivitFile
|
||||
@ -139,9 +138,7 @@ _workflows = {
|
||||
@pytest.fixture(scope="module", autouse=False)
|
||||
@pytest.mark.asyncio
|
||||
async def client(tmp_path_factory) -> EmbeddedComfyClient:
|
||||
config = Configuration()
|
||||
config.cwd = str(tmp_path_factory.mktemp("comfy_test_cwd"))
|
||||
async with EmbeddedComfyClient(config) as client:
|
||||
async with EmbeddedComfyClient() as client:
|
||||
yield client
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user