Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-25 05:40:15 +08:00)

commit 3d1d833e6f
Merge branch 'master' of github.com:comfyanonymous/ComfyUI

.github/ISSUE_TEMPLATE/bug-report.yml (vendored, 9 lines changed)
@@ -7,9 +7,12 @@ body:
       value: |
         Before submitting a **Bug Report**, please ensure the following:

-        **1:** You are running the latest version of ComfyUI.
-        **2:** You have looked at the existing bug reports and made sure this isn't already reported.
-        **3:** This is an actual bug in ComfyUI, not just a support question and not caused by an custom node. A bug is when you can specify exact steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
+        - **1:** You are running the latest version of ComfyUI.
+        - **2:** You have looked at the existing bug reports and made sure this isn't already reported.
+        - **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
+          `--disable-all-custom-nodes` command line argument.
+        - **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
+          steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.

         If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
   - type: textarea

.github/workflows/stable-release.yml (vendored, new file, 109 lines)
@@ -0,0 +1,109 @@
name: "Release Stable Version"

on:
  push:
    tags:
      - 'v*'

jobs:
  package_comfy_windows:
    permissions:
      contents: "write"
      packages: "write"
      pull-requests: "read"
    runs-on: windows-latest
    strategy:
      matrix:
        python_version: [3.11.8]
        cuda_version: [121]
    steps:
      - name: Calculate Minor Version
        shell: bash
        run: |
          # Extract the minor version from the Python version
          MINOR_VERSION=$(echo "${{ matrix.python_version }}" | cut -d'.' -f2)
          echo "MINOR_VERSION=$MINOR_VERSION" >> $GITHUB_ENV
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python_version }}

      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          persist-credentials: false
      - shell: bash
        run: |
          echo "@echo off
          call update_comfyui.bat nopause
          echo -
          echo This will try to update pytorch and all python dependencies.
          echo -
          echo If you just want to update normally, close this and run update_comfyui.bat instead.
          echo -
          pause
          ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r ../ComfyUI/requirements.txt pygit2
          pause" > update_comfyui_and_python_dependencies.bat

          python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r requirements.txt pygit2 -w ./temp_wheel_dir
          python -m pip install --no-cache-dir ./temp_wheel_dir/*
          echo installed basic
          ls -lah temp_wheel_dir
          mv temp_wheel_dir cu${{ matrix.cuda_version }}_python_deps
          mv cu${{ matrix.cuda_version }}_python_deps ../
          mv update_comfyui_and_python_dependencies.bat ../
          cd ..
          pwd
          ls

          cp -r ComfyUI ComfyUI_copy
          curl https://www.python.org/ftp/python/${{ matrix.python_version }}/python-${{ matrix.python_version }}-embed-amd64.zip -o python_embeded.zip
          unzip python_embeded.zip -d python_embeded
          cd python_embeded
          echo ${{ env.MINOR_VERSION }}
          echo 'import site' >> ./python3${{ env.MINOR_VERSION }}._pth
          curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
          ./python.exe get-pip.py
          ./python.exe --version
          echo "Pip version:"
          ./python.exe -m pip --version

          set PATH=$PWD/Scripts:$PATH
          echo $PATH
          ./python.exe -s -m pip install ../cu${{ matrix.cuda_version }}_python_deps/*
          sed -i '1i../ComfyUI' ./python3${{ env.MINOR_VERSION }}._pth
          cd ..

          git clone https://github.com/comfyanonymous/taesd
          cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/

          mkdir ComfyUI_windows_portable
          mv python_embeded ComfyUI_windows_portable
          mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI

          cd ComfyUI_windows_portable

          mkdir update
          cp -r ComfyUI/.ci/update_windows/* ./update/
          cp -r ComfyUI/.ci/windows_base_files/* ./
          cp ../update_comfyui_and_python_dependencies.bat ./update/

          cd ..

          "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
          mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z

          cd ComfyUI_windows_portable
          python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu

          ls

      - name: Upload binaries to release
        uses: svenstaro/upload-release-action@v2
        with:
          repo_token: ${{ secrets.GITHUB_TOKEN }}
          file: ComfyUI_windows_portable_nvidia.7z
          tag: ${{ github.ref }}
          overwrite: true

.github/workflows/test-browser.yml (vendored, 15 lines changed)
@@ -41,7 +41,7 @@ jobs:
         working-directory: ComfyUI
       - name: Start ComfyUI server
         run: |
-          python main.py --cpu &
+          python main.py --cpu 2>&1 | tee console_output.log &
           wait-for-it --service 127.0.0.1:8188 -t 600
         working-directory: ComfyUI
       - name: Install ComfyUI_frontend dependencies

@@ -54,9 +54,22 @@ jobs:
       - name: Run Playwright tests
         run: npx playwright test
         working-directory: ComfyUI_frontend
+      - name: Check for unhandled exceptions in server log
+        run: |
+          if grep -qE "Exception|Error" console_output.log; then
+            echo "Unhandled exception/error found in server log."
+            exit 1
+          fi
+        working-directory: ComfyUI
       - uses: actions/upload-artifact@v4
         if: always()
         with:
           name: playwright-report
           path: ComfyUI_frontend/playwright-report/
           retention-days: 30
+      - uses: actions/upload-artifact@v4
+        if: always()
+        with:
+          name: console-output
+          path: ComfyUI/console_output.log
+          retention-days: 30

@@ -42,6 +42,7 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
 - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
 - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
 - [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
+- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
 - Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Starts up very fast.
 - Works fully offline: will never download anything.

@@ -10,10 +10,51 @@ from ..ldm.modules.diffusionmodules.util import (
     timestep_embedding,
 )
 
-from ..ldm.modules.attention import SpatialTransformer
+from ..ldm.modules.attention import SpatialTransformer, optimized_attention
 from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
 from ..ldm.util import exists
 from .. import ops
+from collections import OrderedDict
+
+
+class OptimizedAttention(nn.Module):
+    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.heads = nhead
+        self.c = c
+
+        self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
+        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x):
+        x = self.in_proj(x)
+        q, k, v = x.split(self.c, dim=2)
+        out = optimized_attention(q, k, v, self.heads)
+        return self.out_proj(out)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResBlockUnionControlnet(nn.Module):
+    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
+        self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
+        self.mlp = nn.Sequential(
+            OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
+                         ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
+        self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
+
+    def attention(self, x: torch.Tensor):
+        return self.attn(x)
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
 
 class ControlledUnetModel(UNetModel):
     #implemented in the ldm unet

@@ -53,6 +94,7 @@ class ControlNet(nn.Module):
                  transformer_depth_middle=None,
                  transformer_depth_output=None,
                  attn_precision=None,
+                 union_controlnet_num_control_type=None,
                  device=None,
                  operations=ops.disable_weight_init,
                  **kwargs,

@@ -280,6 +322,65 @@ class ControlNet(nn.Module):
         self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
         self._feature_size += ch
 
+        if union_controlnet_num_control_type is not None:
+            self.num_control_type = union_controlnet_num_control_type
+            num_trans_channel = 320
+            num_trans_head = 8
+            num_trans_layer = 1
+            num_proj_channel = 320
+            # task_scale_factor = num_trans_channel ** 0.5
+            self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
+
+            self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
+            self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
+            #-----------------------------------------------------------------------------------------------------
+
+            control_add_embed_dim = 256
+            class ControlAddEmbedding(nn.Module):
+                def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
+                    super().__init__()
+                    self.num_control_type = num_control_type
+                    self.in_dim = in_dim
+                    self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
+                    self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
+                def forward(self, control_type, dtype, device):
+                    c_type = torch.zeros((self.num_control_type,), device=device)
+                    c_type[control_type] = 1.0
+                    c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
+                    return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
+
+            self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
+        else:
+            self.task_embedding = None
+            self.control_add_embedding = None
+
+    def union_controlnet_merge(self, hint, control_type, emb, context):
+        # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
+        inputs = []
+        condition_list = []
+
+        for idx in range(min(1, len(control_type))):
+            controlnet_cond = self.input_hint_block(hint[idx], emb, context)
+            feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
+            if idx < len(control_type):
+                feat_seq += self.task_embedding[control_type[idx]]
+
+            inputs.append(feat_seq.unsqueeze(1))
+            condition_list.append(controlnet_cond)
+
+        x = torch.cat(inputs, dim=1)
+        x = self.transformer_layes(x)
+        controlnet_cond_fuser = None
+        for idx in range(len(control_type)):
+            alpha = self.spatial_ch_projs(x[:, idx])
+            alpha = alpha.unsqueeze(-1).unsqueeze(-1)
+            o = condition_list[idx] + alpha
+            if controlnet_cond_fuser is None:
+                controlnet_cond_fuser = o
+            else:
+                controlnet_cond_fuser += o
+        return controlnet_cond_fuser
+
     def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
         return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
 

@@ -287,6 +388,17 @@ class ControlNet(nn.Module):
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
         emb = self.time_embed(t_emb)
 
+        guided_hint = None
+        if self.control_add_embedding is not None: #Union Controlnet
+            control_type = kwargs.get("control_type", [])
+
+            emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
+            if len(control_type) > 0:
+                if len(hint.shape) < 5:
+                    hint = hint.unsqueeze(dim=0)
+                guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
+
+        if guided_hint is None:
             guided_hint = self.input_hint_block(hint, emb, context)
 
         out_output = []
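
Note (not part of the diff): the ControlAddEmbedding module added above one-hot encodes the requested control-type indices, expands each flag with a sinusoidal timestep embedding, and projects the concatenation through two linear layers with a SiLU in between before the result is added to the time embedding in forward(). A minimal standalone sketch of that encoding step, with plain torch.nn.Linear in place of the patched operations and a stand-in for comfy's timestep_embedding helper (sizes are illustrative, not taken from a real checkpoint):

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def sinusoidal_embedding(t, dim, max_period=10000):
        # stand-in for comfy's timestep_embedding(); returns shape (N, dim)
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
        args = t[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    num_control_type, in_dim, out_dim = 6, 256, 1280   # illustrative sizes only
    linear_1 = nn.Linear(in_dim * num_control_type, out_dim)
    linear_2 = nn.Linear(out_dim, out_dim)

    control_type = [0, 3]                     # indices of the active control tasks
    c_type = torch.zeros(num_control_type)
    c_type[control_type] = 1.0                # one-hot flags, as in the diff
    c_type = sinusoidal_embedding(c_type, in_dim).reshape(1, num_control_type * in_dim)
    emb = linear_2(F.silu(linear_1(c_type)))  # added onto the time embedding in forward()
    print(emb.shape)                          # torch.Size([1, 1280])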

@@ -1,3 +1,4 @@
+from .component_model import files
 from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
 import os
 import torch

@@ -30,9 +31,17 @@ def clip_preprocess(image, size=224):
     return (image - mean.view([3,1,1])) / std.view([3,1,1])
 
 class ClipVisionModel():
-    def __init__(self, json_config):
+    def __init__(self, json_config: dict | str):
+        if isinstance(json_config, dict):
+            config = json_config
+        elif json_config is not None and isinstance(json_config, str):
+            if json_config.startswith("{"):
+                config = json.loads(json_config)
+            else:
                 with open(json_config) as f:
                     config = json.load(f)
+        else:
+            raise ValueError(f"json_config had invalid value={json_config}")
+
         self.load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()

@@ -88,12 +97,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     if convert_keys:
         sd = convert_to_transformers(sd, prefix)
     if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
-        # todo: fix the importlib issue here
-        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
+        json_config = files.get_path_as_dict(None, "clip_vision_config_g.json")
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
-        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
+        json_config = files.get_path_as_dict(None, "clip_vision_config_h.json")
     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
-        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
+        json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json")
     else:
         return None
 

@@ -85,7 +85,8 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
 
 def cleanup_temp():
     try:
-        temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
+        folder_paths.get_temp_directory()
+        temp_dir = folder_paths.get_temp_directory()
         if os.path.exists(temp_dir):
             shutil.rmtree(temp_dir, ignore_errors=True)
     except NameError:

@@ -115,7 +116,7 @@ async def main():
 
     # configure extra model paths earlier
     try:
-        extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
+        extra_model_paths_config_path = os.path.join(os.getcwd(), "extra_model_paths.yaml")
         if os.path.isfile(extra_model_paths_config_path):
             load_extra_path_config(extra_model_paths_config_path)
     except NameError:

@@ -439,6 +439,7 @@ class PromptServer(ExecutorToClientProgress):
         info['name'] = node_class
         info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
         info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else ''
+        info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
         info['category'] = 'sd'
         if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
             info['output_node'] = True

@@ -845,18 +846,9 @@ class PromptServer(ExecutorToClientProgress):
 
         return json_data
 
-    @classmethod
-    def get_output_path(cls, subfolder: str | None = None, filename: str | None = None):
-        paths = [path for path in ["output", subfolder, filename] if path is not None and path != ""]
-        return os.path.join(os.path.dirname(os.path.realpath(__file__)), *paths)
-
     @classmethod
     def get_upload_dir(cls) -> str:
-        upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../input")
-
-        if not os.path.exists(upload_dir):
-            os.makedirs(upload_dir)
-        return upload_dir
+        return folder_paths.get_input_directory()
 
     @classmethod
     def get_too_busy_queue_size(cls):

@@ -19,7 +19,7 @@ def get_path_as_dict(config_dict_or_path: str | dict | None, config_path_inside_
     config: dict | None = None
 
     if config_dict_or_path is None:
-        config_dict_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config_path_inside_package)
+        config_dict_or_path = config_path_inside_package
 
     if isinstance(config_dict_or_path, str):
         if config_dict_or_path.startswith("{"):

@@ -412,6 +412,12 @@ def load_controlnet(ckpt_path, model=None):
             if k in controlnet_data:
                 new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
 
+        if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
+            controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
+            for k in list(controlnet_data.keys()):
+                new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
+                new_sd[new_k] = controlnet_data.pop(k)
+
         leftover_keys = controlnet_data.keys()
         if len(leftover_keys) > 0:
             logging.warning("leftover keys: {}".format(leftover_keys))

comfy/ldm/aura/mmdit.py (new file, 479 lines)
@@ -0,0 +1,479 @@
#AuraFlow MMDiT
#Originally written by the AuraFlow Authors

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from comfy.ldm.modules.attention import optimized_attention

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim

        n_hidden = int(2 * hidden_dim / 3)
        n_hidden = find_multiple(n_hidden, 256)

        self.c_fc1 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
        self.c_fc2 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
        self.c_proj = operations.Linear(n_hidden, dim, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x


class MultiHeadLayerNorm(nn.Module):
    def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
        # Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78

        super().__init__()
        self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(
            variance + self.variance_epsilon
        )
        hidden_states = self.weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype)

class SingleAttention(nn.Module):
    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

    #@torch.compile()
    def forward(self, c):

        bsz, seqlen1, _ = c.shape

        q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
        q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
        k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
        v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
        q, k = self.q_norm1(q), self.k_norm1(k)

        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
        c = self.w1o(output)
        return c



class DoubleAttention(nn.Module):
    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # this is for x
        self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

        self.q_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )


    #@torch.compile()
    def forward(self, c, x):

        bsz, seqlen1, _ = c.shape
        bsz, seqlen2, _ = x.shape
        seqlen = seqlen1 + seqlen2

        cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
        cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
        ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cq, ck = self.q_norm1(cq), self.k_norm1(ck)

        xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
        xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xq, xk = self.q_norm2(xq), self.k_norm2(xk)

        # concat all
        q, k, v = (
            torch.cat([cq, xq], dim=1),
            torch.cat([ck, xk], dim=1),
            torch.cat([cv, xv], dim=1),
        )

        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)

        c, x = output.split([seqlen1, seqlen2], dim=1)
        c = self.w1o(c)
        x = self.w2o(x)

        return c, x


class MMDiTBlock(nn.Module):
    def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        if not is_last:
            self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
            )
        else:
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
            )

        self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
        self.modX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.is_last = is_last

    #@torch.compile()
    def forward(self, c, x, global_cond, **kwargs):

        cres, xres = c, x

        cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
            self.modC(global_cond).chunk(6, dim=1)
        )

        c = modulate(self.normC1(c), cshift_msa, cscale_msa)

        # xpath
        xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
            self.modX(global_cond).chunk(6, dim=1)
        )

        x = modulate(self.normX1(x), xshift_msa, xscale_msa)

        # attention
        c, x = self.attn(c, x)


        c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
        c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
        c = cres + c

        x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
        x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
        x = xres + x

        return c, x

class DiTBlock(nn.Module):
    # like MMDiTBlock, but it only has X
    def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
        super().__init__()

        self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)

        self.modCX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

    #@torch.compile()
    def forward(self, cx, global_cond, **kwargs):
        cxres = cx
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
            global_cond
        ).chunk(6, dim=1)
        cx = modulate(self.norm1(cx), shift_msa, scale_msa)
        cx = self.attn(cx)
        cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
        mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
        cx = gate_mlp.unsqueeze(1) * mlpout

        cx = cxres + cx

        return cx



class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = 1000 * torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    #@torch.compile()
    def forward(self, t, dtype):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class MMDiT(nn.Module):
    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        patch_size=2,
        dim=3072,
        n_layers=36,
        n_double_layers=4,
        n_heads=12,
        global_conddim=3072,
        cond_seq_dim=2048,
        max_seq=32 * 32,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype

        self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)

        self.cond_seq_linear = operations.Linear(
            cond_seq_dim, dim, bias=False, dtype=dtype, device=device
        )  # linear for something like text sequence.
        self.init_x_linear = operations.Linear(
            patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
        )  # init linear for patchified image.

        self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
        self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))

        self.double_layers = nn.ModuleList([])
        self.single_layers = nn.ModuleList([])


        for idx in range(n_double_layers):
            self.double_layers.append(
                MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
            )

        for idx in range(n_double_layers, n_layers):
            self.single_layers.append(
                DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
            )


        self.final_linear = operations.Linear(
            dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
        )

        self.modF = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
        )

        self.out_channels = out_channels
        self.patch_size = patch_size
        self.n_double_layers = n_double_layers
        self.n_layers = n_layers

        self.h_max = round(max_seq**0.5)
        self.w_max = round(max_seq**0.5)

    @torch.no_grad()
    def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
        # extend pe
        pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]

        pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)

        # now we need to extend this to target_dim. for this we will use interpolation.
        # we will use torch.nn.functional.interpolate
        pe_as_2d = F.interpolate(
            pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
        )
        pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
        self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
        self.h_max, self.w_max = target_dim
        print("PE extended to", target_dim)

    def pe_selection_index_based_on_dim(self, h, w):
        h_p, w_p = h // self.patch_size, w // self.patch_size
        original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
        original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
        starth = self.h_max // 2 - h_p // 2
        endh =starth + h_p
        startw = self.w_max // 2 - w_p // 2
        endw = startw + w_p
        original_pe_indexes = original_pe_indexes[
            starth:endh, startw:endw
        ]
        return original_pe_indexes.flatten()

    def unpatchify(self, x, h, w):
        c = self.out_channels
        p = self.patch_size

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def patchify(self, x):
        B, C, H, W = x.size()
        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size

        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
        x = x.view(
            B,
            C,
            (H + 1) // self.patch_size,
            self.patch_size,
            (W + 1) // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

    def apply_pos_embeds(self, x, h, w):
        h = (h + 1) // self.patch_size
        w = (w + 1) // self.patch_size
        max_dim = max(h, w)

        cur_dim = self.h_max
        pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)

        if max_dim > cur_dim:
            pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
            cur_dim = max_dim

        from_h = (cur_dim - h) // 2
        from_w = (cur_dim - w) // 2
        pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
        return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

    def forward(self, x, timestep, context, **kwargs):
        # patchify x, add PE
        b, c, h, w = x.shape

        # pe_indexes = self.pe_selection_index_based_on_dim(h, w)
        # print(pe_indexes, pe_indexes.shape)

        x = self.init_x_linear(self.patchify(x))  # B, T_x, D
        x = self.apply_pos_embeds(x, h, w)
        # x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
        # x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)

        # process conditions for MMDiT Blocks
        c_seq = context  # B, T_c, D_c
        t = timestep

        c = self.cond_seq_linear(c_seq)  # B, T_c, D
        c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)

        global_cond = self.t_embedder(t, x.dtype)  # B, D

        if len(self.double_layers) > 0:
            for layer in self.double_layers:
                c, x = layer(c, x, global_cond, **kwargs)

        if len(self.single_layers) > 0:
            c_len = c.size(1)
            cx = torch.cat([c, x], dim=1)
            for layer in self.single_layers:
                cx = layer(cx, global_cond, **kwargs)

            x = cx[:, c_len:]

        fshift, fscale = self.modF(global_cond).chunk(2, dim=1)

        x = modulate(x, fshift, fscale)
        x = self.final_linear(x)
        x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
        return x

@@ -272,4 +272,12 @@ def model_lora_keys_unet(model, key_map={}):
             key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
             key_map[key_lora] = to
 
+    if isinstance(model, model_base.AuraFlow): #Diffusers lora AuraFlow
+        diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+        for k in diffusers_keys:
+            if k.endswith(".weight"):
+                to = diffusers_keys[k]
+                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
+                key_map[key_lora] = to
+
     return key_map

@@ -17,7 +17,7 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
 from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
+from .ldm.aura.mmdit import MMDiT as AuraMMDiT
 
 class ModelType(Enum):
     EPS = 1

@@ -622,6 +622,17 @@ class SD3(BaseModel):
         area = input_shape[0] * input_shape[2] * input_shape[3]
         return (area * 0.3) * (1024 * 1024)
 
+class AuraFlow(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=AuraMMDiT)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = conds.CONDRegular(cross_attn)
+        return out
+
+
 class StableAudio1(BaseModel):
     def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):

@@ -104,6 +104,19 @@ def detect_unet_config(state_dict, key_prefix):
         unet_config["audio_model"] = "dit1.0"
         return unet_config
 
+    if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
+        unet_config = {}
+        unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
+        unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
+        double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
+        single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
+        unet_config["n_double_layers"] = double_layers
+        unet_config["n_layers"] = double_layers + single_layers
+        return unet_config
+
+    if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
+        return None
+
     unet_config = {
         "use_checkpoint": False,
         "image_size": 32,

@@ -238,6 +251,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
 
 def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
     unet_config = detect_unet_config(state_dict, unet_key_prefix)
+    if unet_config is None:
+        return None
     model_config = model_config_from_unet_config(unet_config, state_dict)
     if model_config is None and use_base_if_no_match:
         return supported_models_base.BASE(unet_config)

@@ -247,6 +262,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
 def unet_prefix_from_state_dict(state_dict):
     if "model.model.postprocess_conv.weight" in state_dict: #audio models
         unet_key_prefix = "model.model."
+    elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
+        unet_key_prefix = "model."
     else:
         unet_key_prefix = "model.diffusion_model."
     return unet_key_prefix

@@ -436,12 +453,19 @@ def model_config_from_diffusers_unet(state_dict):
     return None
 
 def convert_diffusers_mmdit(state_dict, output_prefix=""):
-    out_sd = None
-    num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
-    if num_blocks > 0:
-        depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
     out_sd = {}
 
+    if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
+        depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         sd_map = utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
+    elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
+        num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
+        num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+        sd_map = utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
+    else:
+        return None
+
     for k in sd_map:
         weight = state_dict.get(k, None)
         if weight is not None:

@@ -60,6 +60,12 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl
     model_options["disable_cfg1_optimization"] = True
     return model_options
 
+def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
+    model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
+    if disable_cfg1_optimization:
+        model_options["disable_cfg1_optimization"] = True
+    return model_options
+
 class ModelPatcher(ModelManageable):
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
         self.size = size

@@ -142,6 +148,9 @@ class ModelPatcher(ModelManageable):
     def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
         self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
 
+    def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
+        self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
+
     def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
         self.model_options["model_function_wrapper"] = unet_wrapper_function
 
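
Note (outside the diff): the new pre-CFG hook mirrors the existing post-CFG one; the helper just appends the callback to a list under model_options["sampler_pre_cfg_function"], so several hooks can be stacked, and ModelPatcher.set_model_sampler_pre_cfg_function is a thin wrapper around it. A self-contained sketch of that accumulation behavior (the callback body is a placeholder, since the argument convention the sampler uses when calling these hooks is not part of this hunk):

    # copy of the helper added above, so the sketch runs without comfy installed
    def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
        model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
        if disable_cfg1_optimization:
            model_options["disable_cfg1_optimization"] = True
        return model_options

    def my_pre_cfg(args):
        # hypothetical callback: inspect/adjust whatever the sampler hands in, then return it
        return args

    opts = {}
    opts = set_model_options_pre_cfg_function(opts, my_pre_cfg, disable_cfg1_optimization=True)
    opts = set_model_options_pre_cfg_function(opts, my_pre_cfg)
    print(len(opts["sampler_pre_cfg_function"]), opts["disable_cfg1_optimization"])  # 2 True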

@@ -192,11 +192,12 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
         else:
             sampling_settings = {}
 
-        self.set_parameters(shift=sampling_settings.get("shift", 1.0))
+        self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
 
-    def set_parameters(self, shift=1.0, timesteps=1000):
+    def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
         self.shift = shift
-        ts = self.sigma(torch.arange(1, timesteps + 1, 1))
+        self.multiplier = multiplier
+        ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
         self.register_buffer('sigmas', ts)
 
     @property

@@ -208,10 +209,10 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
         return self.sigmas[-1]
 
     def timestep(self, sigma):
-        return sigma * 1000
+        return sigma * self.multiplier
 
     def sigma(self, timestep):
-        return time_snr_shift(self.shift, timestep / 1000)
+        return time_snr_shift(self.shift, timestep / self.multiplier)
 
     def percent_to_sigma(self, percent):
         if percent <= 0.0:
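
Note (outside the diff): the multiplier only rescales the discrete timestep axis; because set_parameters() divides by timesteps and sigma() divides by the same multiplier, the registered sigma schedule is unchanged for the default multiplier of 1000, while timestep(sigma) becomes sigma * multiplier. A quick standalone check, assuming time_snr_shift(alpha, t) = alpha * t / (1 + (alpha - 1) * t) as defined elsewhere in model_sampling.py (reproduced here so the snippet runs on its own; the shift value is illustrative):

    import torch

    def time_snr_shift(alpha, t):
        # assumed to match the helper ModelSamplingDiscreteFlow already uses
        return alpha * t / (1 + (alpha - 1) * t)

    shift, timesteps, multiplier = 1.73, 1000, 1000

    t = (torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier  # what set_parameters() now feeds sigma()
    sigmas = time_snr_shift(shift, t / multiplier)                    # sigma(t); the multiplier cancels out
    print(sigmas[-1].item())                 # 1.0 at the final step
    print((sigmas * multiplier)[-1].item())  # timestep(sigma) == sigma * multiplier -> 1000.0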

@@ -46,8 +46,9 @@ class CLIPTextEncode:
 
     def encode(self, clip, text):
         tokens = clip.tokenize(text)
-        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
-        return ([[cond, {"pooled_output": pooled}]], )
+        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
+        cond = output.pop("cond")
+        return ([[cond, output]], )
 
 class ConditioningCombine:
     @classmethod
@ -223,8 +224,9 @@ class ConditioningZeroOut:
|
|||||||
c = []
|
c = []
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
d = t[1].copy()
|
d = t[1].copy()
|
||||||
if "pooled_output" in d:
|
pooled_output = d.get("pooled_output", None)
|
||||||
d["pooled_output"] = torch.zeros_like(d["pooled_output"])
|
if pooled_output is not None:
|
||||||
|
d["pooled_output"] = torch.zeros_like(pooled_output)
|
||||||
n = [torch.zeros_like(t[0]), d]
|
n = [torch.zeros_like(t[0]), d]
|
||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|||||||
@@ -1,22 +1,27 @@
-from comfy import sd1_clip
 from transformers import T5TokenizerFast
 
 import comfy.t5
-import os
+from comfy import sd1_clip
+from comfy.component_model import files
 
 
 class T5BaseModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
-        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json")
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
+        textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_base.json")
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
 
 
 class T5BaseTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None):
-        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+        tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer")
         super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
 
 
 class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None):
         super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer)
 
 
 class SAT5Model(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, **kwargs):
-        super().__init__(device=device, dtype=dtype, clip_name="t5base", clip_model=T5BaseModel, **kwargs)
+        super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
@@ -278,6 +278,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
 
         conds = [cond, uncond_]
         out = calc_cond_batch(model, conds, x, timestep, model_options)
 
+        for fn in model_options.get("sampler_pre_cfg_function", []):
+            args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
+                    "input": x, "sigma": timestep, "model": model, "model_options": model_options}
+            out = fn(args)
+
         return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
 
 
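A pre-CFG function therefore runs after both conditional and unconditional batches have been evaluated but before cfg_function mixes them, so it can rewrite the raw predictions in place. A minimal hedged sketch of such a hook; the 0.95 damping factor is purely illustrative:

    def damp_uncond(args):
        cond_out, uncond_out = args["conds_out"]
        if uncond_out is None:                 # the cfg == 1.0 optimization may skip the uncond pass
            return args["conds_out"]
        return [cond_out, uncond_out * 0.95]

    # registered through the ModelPatcher hook added earlier in this diff:
    # patcher.set_model_sampler_pre_cfg_function(damp_uncond, disable_cfg1_optimization=True)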
72 comfy/sd.py
@@ -28,37 +28,7 @@ from .t2i_adapter import adapter
 from .taesd import taesd
 from . import sd3_clip
 from . import sa_t5
+from .text_encoders import aura_t5
 
-def load_model_weights(model, sd):
-    m, u = model.load_state_dict(sd, strict=False)
-    m = set(m)
-    unexpected_keys = set(u)
-
-    k = list(sd.keys())
-    for x in k:
-        if x not in unexpected_keys:
-            w = sd.pop(x)
-            del w
-    if len(m) > 0:
-        logging.warning("missing {}".format(m))
-    return model
-
-
-def load_clip_weights(model, sd):
-    k = list(sd.keys())
-    for x in k:
-        if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
-            y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
-            sd[y] = sd.pop(x)
-
-    if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd:
-        ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids']
-        if ids.dtype == torch.float32:
-            sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
-
-    sd = utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.")
-    return load_model_weights(model, sd)
-
-
 def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
@@ -136,7 +106,7 @@ class CLIP:
     def tokenize(self, text, return_word_ids=False):
         return self.tokenizer.tokenize_with_weights(text, return_word_ids)
 
-    def encode_from_tokens(self, tokens, return_pooled=False):
+    def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
         self.cond_stage_model.reset_clip_options()
 
         if self.layer_idx is not None:
@@ -146,7 +116,15 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model()
-        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
+        o = self.cond_stage_model.encode_token_weights(tokens)
+        cond, pooled = o[:2]
+        if return_dict:
+            out = {"cond": cond, "pooled_output": pooled}
+            if len(o) > 2:
+                for k in o[2]:
+                    out[k] = o[2][k]
+            return out
+
         if return_pooled:
             return cond, pooled
         return cond
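Callers that opt into return_dict=True receive the embeddings and every extra encoder output in one dictionary instead of a tuple; this is what CLIPTextEncode above now consumes. A hedged sketch of the expected shape (the attention_mask key only appears when the underlying encoder was constructed with return_attention_masks=True):

    tokens = clip.tokenize("a photo of a cat")
    output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
    cond = output.pop("cond")                  # token embedding tensor
    # whatever remains ("pooled_output", possibly "attention_mask", ...) rides along as conditioning metadata
    conditioning = [[cond, output]]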
@@ -447,9 +425,14 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
             clip_target.clip = sd2_clip.SD2ClipModel
             clip_target.tokenizer = sd2_clip.SD2Tokenizer
         elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
-            dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype
-            clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
-            clip_target.tokenizer = sd3_clip.SD3Tokenizer
+            weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
+            dtype_t5 = weight.dtype
+            if weight.shape[-1] == 4096:
+                clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
+                clip_target.tokenizer = sd3_clip.SD3Tokenizer
+            elif weight.shape[-1] == 2048:
+                clip_target.clip = aura_t5.AuraT5Model
+                clip_target.tokenizer = aura_t5.AuraT5Tokenizer
         elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
             clip_target.clip = sa_t5.SAT5Model
             clip_target.tokenizer = sa_t5.SAT5Tokenizer
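The new branch distinguishes the two T5 checkpoints by the width of one feed-forward weight: 4096 columns indicates the SD3-style T5-XXL encoder, 2048 indicates the Pile T5-XL encoder that AuraFlow uses. A standalone illustration of the same check, with a fabricated state dict and made-up shapes:

    import torch

    sd = {"encoder.block.23.layer.1.DenseReluDense.wi_1.weight": torch.zeros(10240, 4096)}  # hypothetical
    width = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].shape[-1]
    encoder = {4096: "t5xxl (SD3)", 2048: "pile_t5xl (AuraFlow)"}.get(width, "unknown")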
@@ -529,13 +512,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     load_device = model_management.get_torch_device()
 
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
+    if model_config is None:
+        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+
     unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
 
-    if model_config is None:
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
-
     if model_config.clip_vision_prefix is not None:
         if output_clipvision:
             clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
@@ -586,30 +569,25 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 def load_unet_state_dict(sd): # load unet in diffusers or regular format
 
     #Allow loading unets from checkpoint files
-    checkpoint = False
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     temp_sd = utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
     if len(temp_sd) > 0:
         sd = temp_sd
-        checkpoint = True
 
     parameters = utils.calculate_parameters(sd)
     unet_dtype = model_management.unet_dtype(model_params=parameters)
     load_device = model_management.get_torch_device()
 
-    if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: # ldm or stable cascade
-        model_config = model_detection.model_config_from_unet(sd, "")
-        if model_config is None:
-            return None
-        new_sd = sd
-    elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
-        new_sd = model_detection.convert_diffusers_mmdit(sd, "")
-        if new_sd is None:
-            return None
-        model_config = model_detection.model_config_from_unet(new_sd, "")
-        if model_config is None:
-            return None
-    else: # diffusers
-        model_config = model_detection.model_config_from_diffusers_unet(sd)
-        if model_config is None:
-            return None
+    model_config = model_detection.model_config_from_unet(sd, "")
+    if model_config is not None:
+        new_sd = sd
+    else:
+        new_sd = model_detection.convert_diffusers_mmdit(sd, "")
+        if new_sd is not None: #diffusers mmdit
+            model_config = model_detection.model_config_from_unet(new_sd, "")
+            if model_config is None:
+                return None
+        else: # diffusers unet
+            model_config = model_detection.model_config_from_diffusers_unet(sd)
+            if model_config is None:
+                return None
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
 import copy
+import importlib.resources
 import logging
 import numbers
 import os
 import traceback
 import zipfile
+from importlib.abc import Traversable
 from typing import Tuple, Sequence, TypeVar
 
 import torch
@@ -14,6 +16,7 @@ from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMi
 from . import clip_model
 from . import model_management
 from . import ops
+from .component_model import files
 from .component_model.files import get_path_as_dict, get_package_as_path
 
 
@@ -29,7 +32,58 @@ def gen_empty_tokens(special_tokens, length):
         output += [pad_token] * (length - len(output))
     return output
 
-class SDClipModel(torch.nn.Module):
+class ClipTokenWeightEncoder:
+    def encode_token_weights(self, token_weight_pairs):
+        to_encode = list()
+        max_token_len = 0
+        has_weights = False
+        for x in token_weight_pairs:
+            tokens = list(map(lambda a: a[0], x))
+            max_token_len = max(len(tokens), max_token_len)
+            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
+            to_encode.append(tokens)
+
+        sections = len(to_encode)
+        if has_weights or sections == 0:
+            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
+
+        o = self.encode(to_encode)
+        out, pooled = o[:2]
+
+        if pooled is not None:
+            first_pooled = pooled[0:1].to(model_management.intermediate_device())
+        else:
+            first_pooled = pooled
+
+        output = []
+        for k in range(0, sections):
+            z = out[k:k+1]
+            if has_weights:
+                z_empty = out[-1]
+                for i in range(len(z)):
+                    for j in range(len(z[i])):
+                        weight = token_weight_pairs[k][j][1]
+                        if weight != 1.0:
+                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
+            output.append(z)
+
+        if (len(output) == 0):
+            r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
+        else:
+            r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
+
+        if len(o) > 2:
+            extra = {}
+            for k in o[2]:
+                v = o[2][k]
+                if k == "attention_mask":
+                    v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
+                extra[k] = v
+
+            r = r + (extra,)
+        return r
+
+class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = [
         "last",
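The new mixin factors the weighted-token logic out of SDClipModel so any encoder that exposes encode() and special_tokens can reuse it. It expects the nested (token, weight) lists produced by tokenize_with_weights and, when any weight differs from 1.0, also encodes an empty prompt so each weighted embedding can be interpolated toward it. A schematic of the input and of the blend; the token ids are illustrative only:

    # one chunk of (token_id, weight) pairs, e.g. from "(a photo:1.2) of a cat"
    token_weight_pairs = [[(49406, 1.0), (320, 1.0), (1125, 1.2), (539, 1.0), (2368, 1.0), (49407, 1.0)]]
    # for a weighted position j the mixin computes:
    #   z[j] = (z_prompt[j] - z_empty[j]) * weight + z_empty[j]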
@@ -40,7 +94,7 @@ class SDClipModel(torch.nn.Module):
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
                  special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
-                 return_projected_pooled=True):  # clip-vit-base-patch32
+                 return_projected_pooled=True, return_attention_masks=False):  # clip-vit-base-patch32
         super().__init__()
         if special_tokens is None:
             special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
@@ -63,6 +117,7 @@ class SDClipModel(torch.nn.Module):
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
+        self.return_attention_masks = return_attention_masks
 
         if layer == "hidden":
             assert layer_idx is not None
@@ -136,7 +191,7 @@ class SDClipModel(torch.nn.Module):
         tokens = torch.tensor(tokens, dtype=torch.long).to(device)
 
         attention_mask = None
-        if self.enable_attention_masks:
+        if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
             attention_mask = torch.zeros_like(tokens)
             end_token = self.special_tokens.get("end", -1)
             for x in range(attention_mask.shape[0]):
@@ -145,7 +200,11 @@ class SDClipModel(torch.nn.Module):
                     if tokens[x, y] == end_token:
                         break
 
-        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        attention_mask_model = None
+        if self.enable_attention_masks:
+            attention_mask_model = attention_mask
+
+        outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
         self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
@@ -153,7 +212,7 @@ class SDClipModel(torch.nn.Module):
         else:
             z = outputs[1].float()
 
-        if self.zero_out_masked and attention_mask is not None:
+        if self.zero_out_masked:
             z *= attention_mask.unsqueeze(-1).float()
 
         pooled_output = None
@@ -163,6 +222,13 @@ class SDClipModel(torch.nn.Module):
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()
 
+        extra = {}
+        if self.return_attention_masks:
+            extra["attention_mask"] = attention_mask
+
+        if len(extra) > 0:
+            return z, pooled_output, extra
+
         return z, pooled_output
 
     def encode(self, tokens):
@@ -374,10 +440,13 @@ SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
 
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None):
         if tokenizer_path is None:
-            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
-        if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
+            tokenizer_path = files.get_package_as_path("comfy.sd1_tokenizer")
+        if isinstance(tokenizer_path, Traversable):
+            contextlib_path = importlib.resources.as_file(tokenizer_path)
+            tokenizer_path = contextlib_path.__enter__()
+        if not tokenizer_path.endswith(".model") and not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
             # package based
             tokenizer_path = get_package_as_path('comfy.sd1_tokenizer')
         self.tokenizer_class = tokenizer_class
@@ -395,6 +464,14 @@ class SDTokenizer:
             self.tokens_start = 0
             self.start_token = None
             self.end_token = empty[0]
+
+        if pad_token is not None:
+            self.pad_token = pad_token
+        elif pad_with_end:
+            self.pad_token = self.end_token
+        else:
+            self.pad_token = 0
+
         self.pad_with_end = pad_with_end
         self.pad_to_max_length = pad_to_max_length
         self.additional_tokens: Tuple[str, ...] = ()
@@ -439,10 +516,6 @@ class SDTokenizer:
         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
         Returned list has the dimensions NxM where M is the input size of CLIP
         '''
-        if self.pad_with_end:
-            pad_token = self.end_token
-        else:
-            pad_token = 0
-
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
@@ -502,7 +575,7 @@ class SDTokenizer:
                     else:
                         batch.append((self.end_token, 1.0, 0))
                     if self.pad_to_max_length:
-                        batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                        batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
                     # start new batch
                     batch = []
                     if self.start_token is not None:
@@ -515,9 +588,9 @@ class SDTokenizer:
         # fill last batch
         batch.append((self.end_token, 1.0, 0))
         if self.pad_to_max_length:
-            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
         if self.min_length is not None and len(batch) < self.min_length:
-            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
+            batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
@@ -560,10 +633,16 @@ class SD1Tokenizer:
 
 
 class SD1ClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs):
+    def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs):
         super().__init__()
 
-        self.clip_name = clip_name
-        self.clip = "clip_{}".format(self.clip_name)
+        if name is not None:
+            self.clip_name = name
+            self.clip = "{}".format(self.clip_name)
+        else:
+            self.clip_name = clip_name
+            self.clip = "clip_{}".format(self.clip_name)
 
         setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))
 
         self.dtypes = set()
@@ -578,8 +657,8 @@ class SD1ClipModel(torch.nn.Module):
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs[self.clip_name]
-        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
-        return out, pooled
+        out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
+        return out
 
     def load_sd(self, sd):
         return getattr(self, self.clip).load_sd(sd)
@@ -1,32 +1,38 @@
+import logging
+import os
+
+import torch
+from transformers import T5TokenizerFast
+
+import comfy.model_management
+import comfy.t5
 from comfy import sd1_clip
 from comfy import sdxl_clip
-from transformers import T5TokenizerFast
-import comfy.t5
-import torch
-import os
-import comfy.model_management
-import logging
+from comfy.component_model import files
 
 class T5XXLModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
-        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
+        textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json")
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
 
 
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None):
-        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+        tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer")
         super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
 
 
 class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None):
         super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
 
 
 class SDT5XXLModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, **kwargs):
         super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)
 
 
 class SD3Tokenizer:
     def __init__(self, embedding_directory=None):
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
@@ -43,6 +49,7 @@ class SD3Tokenizer:
     def untokenize(self, token_weight_pair):
         return self.clip_g.untokenize(token_weight_pair)
 
+
 class SD3ClipModel(torch.nn.Module):
     def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
         super().__init__()
@@ -143,8 +150,10 @@ class SD3ClipModel(torch.nn.Module):
         else:
             return self.t5xxl.load_sd(sd)
 
+
 def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
     class SD3ClipModel_(SD3ClipModel):
         def __init__(self, device="cpu", dtype=None):
             super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
 
     return SD3ClipModel_
@@ -7,6 +7,7 @@ from . import sd2_clip
 from . import sdxl_clip
 from . import sd3_clip
 from . import sa_t5
+from .text_encoders import aura_t5
 
 from . import supported_models_base
 from . import latent_formats
@@ -556,7 +557,29 @@ class StableAudio(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
 
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
+class AuraFlow(supported_models_base.BASE):
+    unet_config = {
+        "cond_seq_dim": 2048,
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 1.73,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.SDXL
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.AuraFlow(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        return supported_models_base.ClipTarget(aura_t5.AuraT5Tokenizer, aura_t5.AuraT5Model)
+
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
 
 models += [SVD_img2vid]
35 comfy/t5.py
@@ -13,29 +13,36 @@ class T5LayerNorm(torch.nn.Module):
         x = x * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight.to(device=x.device, dtype=x.dtype) * x
 
+activations = {
+    "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
+    "relu": torch.nn.functional.relu,
+}
+
 class T5DenseActDense(torch.nn.Module):
-    def __init__(self, model_dim, ff_dim, dtype, device, operations):
+    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
         super().__init__()
         self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
         self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
         # self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = activations[ff_activation]
 
     def forward(self, x):
-        x = torch.nn.functional.relu(self.wi(x))
+        x = self.act(self.wi(x))
         # x = self.dropout(x)
         x = self.wo(x)
         return x
 
 class T5DenseGatedActDense(torch.nn.Module):
-    def __init__(self, model_dim, ff_dim, dtype, device, operations):
+    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
         super().__init__()
         self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
         self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
         self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
         # self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = activations[ff_activation]
 
     def forward(self, x):
-        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
+        hidden_gelu = self.act(self.wi_0(x))
         hidden_linear = self.wi_1(x)
         x = hidden_gelu * hidden_linear
         # x = self.dropout(x)
@@ -43,12 +50,12 @@ class T5DenseGatedActDense(torch.nn.Module):
         return x
 
 class T5LayerFF(torch.nn.Module):
-    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
+    def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
         super().__init__()
-        if ff_activation == "gelu_pytorch_tanh":
-            self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations)
-        elif ff_activation == "relu":
-            self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations)
+        if gated_act:
+            self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
+        else:
+            self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
 
         self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
         # self.dropout = nn.Dropout(config.dropout_rate)
@@ -171,11 +178,11 @@ class T5LayerSelfAttention(torch.nn.Module):
         return x, past_bias
 
 class T5Block(torch.nn.Module):
-    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations):
+    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations):
         super().__init__()
         self.layer = torch.nn.ModuleList()
         self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
-        self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations))
+        self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations))
 
     def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
         x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
@@ -183,11 +190,11 @@ class T5Block(torch.nn.Module):
         return x, past_bias
 
 class T5Stack(torch.nn.Module):
-    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations):
+    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention, dtype, device, operations):
         super().__init__()
 
         self.block = torch.nn.ModuleList(
-            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
+            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=((not relative_attention) or (i == 0)), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
         )
         self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
         # self.dropout = nn.Dropout(config.dropout_rate)
@@ -216,7 +223,7 @@ class T5(torch.nn.Module):
         self.num_layers = config_dict["num_layers"]
         model_dim = config_dict["d_model"]
 
-        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations)
+        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] == "t5", dtype, device, operations)
         self.dtype = dtype
         self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
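Both feed-forward variants now read their activation from the shared activations table, and T5Stack decides per layer whether to instantiate a relative-attention bias: only block 0 for model_type "t5", every block otherwise (as in the umt5 Pile T5 config added below). A stripped-down sketch of the gated activation path with toy dimensions, not the real config values:

    import torch

    d_model, d_ff = 8, 16                                 # toy sizes
    wi_0 = torch.nn.Linear(d_model, d_ff, bias=False)     # gate projection
    wi_1 = torch.nn.Linear(d_model, d_ff, bias=False)     # linear projection
    wo = torch.nn.Linear(d_ff, d_model, bias=False)
    act = lambda a: torch.nn.functional.gelu(a, approximate="tanh")

    x = torch.randn(1, 3, d_model)
    y = wo(act(wi_0(x)) * wi_1(x))                        # the activated branch gates the linear branch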
@@ -8,6 +8,7 @@
   "dense_act_fn": "relu",
   "initializer_factor": 1.0,
   "is_encoder_decoder": true,
+  "is_gated_act": false,
   "layer_norm_epsilon": 1e-06,
   "model_type": "t5",
   "num_decoder_layers": 12,
@@ -8,6 +8,7 @@
   "dense_act_fn": "gelu_pytorch_tanh",
   "initializer_factor": 1.0,
   "is_encoder_decoder": true,
+  "is_gated_act": true,
   "layer_norm_epsilon": 1e-06,
   "model_type": "t5",
   "num_decoder_layers": 24,
0 comfy/text_encoders/__init__.py Normal file
28 comfy/text_encoders/aura_t5.py Normal file
@@ -0,0 +1,28 @@
+from importlib import resources
+
+from comfy import sd1_clip
+from .llama_tokenizer import LLAMATokenizer
+from .. import t5
+from ..component_model.files import get_path_as_dict
+
+
+class PT5XlModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
+        textmodel_json_config = get_path_as_dict(textmodel_json_config, "t5_pile_config_xl.json", package="comfy.text_encoders")
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=t5.T5, enable_attention_masks=True, zero_out_masked=True)
+
+
+class PT5XlTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None):
+        tokenizer_path = resources.files("comfy.text_encoders.t5_pile_tokenizer") / "tokenizer.model"
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LLAMATokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
+
+
+class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None):
+        super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
+
+
+class AuraT5Model(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, **kwargs):
+        super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
22 comfy/text_encoders/llama_tokenizer.py Normal file
@@ -0,0 +1,22 @@
+import os
+
+class LLAMATokenizer:
+    @staticmethod
+    def from_pretrained(path):
+        return LLAMATokenizer(path)
+
+    def __init__(self, tokenizer_path):
+        import sentencepiece
+        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
+        self.end = self.tokenizer.eos_id()
+
+    def get_vocab(self):
+        out = {}
+        for i in range(self.tokenizer.get_piece_size()):
+            out[self.tokenizer.id_to_piece(i)] = i
+        return out
+
+    def __call__(self, string):
+        out = self.tokenizer.encode(string)
+        out += [self.end]
+        return {"input_ids": out}
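LLAMATokenizer implements just the slice of the Hugging Face tokenizer interface that SDTokenizer touches: from_pretrained, get_vocab, and a call that returns input_ids, all driven by a raw SentencePiece model. A hedged usage sketch; the file path is a placeholder:

    tok = LLAMATokenizer.from_pretrained("/path/to/tokenizer.model")   # hypothetical location
    ids = tok("a cute cat")["input_ids"]        # SentencePiece ids followed by the EOS id
    vocab_size = len(tok.get_vocab())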
22 comfy/text_encoders/t5_pile_config_xl.json Normal file
@@ -0,0 +1,22 @@
+{
+  "d_ff": 5120,
+  "d_kv": 64,
+  "d_model": 2048,
+  "decoder_start_token_id": 0,
+  "dropout_rate": 0.1,
+  "eos_token_id": 2,
+  "dense_act_fn": "gelu_pytorch_tanh",
+  "initializer_factor": 1.0,
+  "is_encoder_decoder": true,
+  "is_gated_act": true,
+  "layer_norm_epsilon": 1e-06,
+  "model_type": "umt5",
+  "num_decoder_layers": 24,
+  "num_heads": 32,
+  "num_layers": 24,
+  "output_past": true,
+  "pad_token_id": 1,
+  "relative_attention_num_buckets": 32,
+  "tie_word_embeddings": false,
+  "vocab_size": 32128
+}
0 comfy/text_encoders/t5_pile_tokenizer/__init__.py Normal file
BIN comfy/text_encoders/t5_pile_tokenizer/tokenizer.model Normal file (binary file not shown)
@@ -20,6 +20,8 @@ from PIL import Image
 from tqdm import tqdm
 
 from . import checkpoint_pickle, interruption
+from .component_model import files
+from .component_model.deprecation import _deprecate_method
 from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
 from .component_model.queue_types import BinaryEventTypes
 from .execution_context import current_execution_context
@@ -374,6 +376,76 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
     return key_map
 
 
+def auraflow_to_diffusers(mmdit_config, output_prefix=""):
+    n_double_layers = mmdit_config.get("n_double_layers", 0)
+    n_layers = mmdit_config.get("n_layers", 0)
+
+    key_map = {}
+    for i in range(n_layers):
+        if i < n_double_layers:
+            index = i
+            prefix_from = "joint_transformer_blocks"
+            prefix_to = "{}double_layers".format(output_prefix)
+            block_map = {
+                "attn.to_q.weight": "attn.w2q.weight",
+                "attn.to_k.weight": "attn.w2k.weight",
+                "attn.to_v.weight": "attn.w2v.weight",
+                "attn.to_out.0.weight": "attn.w2o.weight",
+                "attn.add_q_proj.weight": "attn.w1q.weight",
+                "attn.add_k_proj.weight": "attn.w1k.weight",
+                "attn.add_v_proj.weight": "attn.w1v.weight",
+                "attn.to_add_out.weight": "attn.w1o.weight",
+                "ff.linear_1.weight": "mlpX.c_fc1.weight",
+                "ff.linear_2.weight": "mlpX.c_fc2.weight",
+                "ff.out_projection.weight": "mlpX.c_proj.weight",
+                "ff_context.linear_1.weight": "mlpC.c_fc1.weight",
+                "ff_context.linear_2.weight": "mlpC.c_fc2.weight",
+                "ff_context.out_projection.weight": "mlpC.c_proj.weight",
+                "norm1.linear.weight": "modX.1.weight",
+                "norm1_context.linear.weight": "modC.1.weight",
+            }
+        else:
+            index = i - n_double_layers
+            prefix_from = "single_transformer_blocks"
+            prefix_to = "{}single_layers".format(output_prefix)
+
+            block_map = {
+                "attn.to_q.weight": "attn.w1q.weight",
+                "attn.to_k.weight": "attn.w1k.weight",
+                "attn.to_v.weight": "attn.w1v.weight",
+                "attn.to_out.0.weight": "attn.w1o.weight",
+                "norm1.linear.weight": "modCX.1.weight",
+                "ff.linear_1.weight": "mlp.c_fc1.weight",
+                "ff.linear_2.weight": "mlp.c_fc2.weight",
+                "ff.out_projection.weight": "mlp.c_proj.weight"
+            }
+
+        for k in block_map:
+            key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])
+
+    MAP_BASIC = {
+        ("positional_encoding", "pos_embed.pos_embed"),
+        ("register_tokens", "register_tokens"),
+        ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
+        ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
+        ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
+        ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
+        ("cond_seq_linear.weight", "context_embedder.weight"),
+        ("init_x_linear.weight", "pos_embed.proj.weight"),
+        ("init_x_linear.bias", "pos_embed.proj.bias"),
+        ("final_linear.weight", "proj_out.weight"),
+        ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
+    }
+
+    for k in MAP_BASIC:
+        if len(k) > 2:
+            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
+        else:
+            key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+
+    return key_map
+
+
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
         return tensor.narrow(dim, 0, batch_size)
@@ -675,8 +747,9 @@ class ProgressBar:
         self.update_absolute(self.current + value)
 
 
+@_deprecate_method(version="1.0.0", message="The root project directory isn't valid when the application is installed as a package. Use os.getcwd() instead.")
 def get_project_root() -> str:
-    return os.path.join(os.path.dirname(__file__), "..")
+    return files.get_package_as_path("comfy")
 
 
 @contextmanager
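auraflow_to_diffusers follows the same convention as the other *_to_diffusers maps: keys are diffusers-style names, plain string values are one-to-one renames into the native AuraFlow layout, and tuple values additionally carry a transform such as swap_scale_shift. A rough sketch of applying such a map to a diffusers state dict; the helper name is illustrative, not the repository's actual conversion path:

    def apply_key_map(diffusers_sd, key_map):
        out = {}
        for k, v in diffusers_sd.items():
            target = key_map.get(k)
            if target is None:
                continue                          # keys without a mapping are simply dropped here
            if isinstance(target, tuple):         # (native_key, ..., transform)
                out[target[0]] = target[-1](v)
            else:
                out[target] = v
        return out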
@@ -1599,7 +1599,7 @@ export class ComfyApp {
 			if (json) {
 				const workflow = JSON.parse(json);
 				const workflowName = getStorageValue("Comfy.PreviousWorkflow");
-				await this.loadGraphData(workflow, true, workflowName);
+				await this.loadGraphData(workflow, true, true, workflowName);
 				return true;
 			}
 		};
@@ -182,6 +182,11 @@ export class ComfyWorkflowsMenu {
 	 * @param {ComfyWorkflow} workflow
 	 */
 	async function sendToWorkflow(img, workflow) {
+		const openWorkflow = app.workflowManager.openWorkflows.find((w) => w.path === workflow.path);
+		if (openWorkflow) {
+			workflow = openWorkflow;
+		}
+
 		await workflow.load();
 		let options = [];
 		const nodes = app.graph.computeExecutionOrder(false);
@@ -214,7 +219,8 @@ export class ComfyWorkflowsMenu {
 		nodeType.prototype["getExtraMenuOptions"] = function (_, options) {
 			const r = getExtraMenuOptions?.apply?.(this, arguments);
 
-			if (app.ui.settings.getSettingValue("Comfy.UseNewMenu", false) === true) {
+			const setting = app.ui.settings.getSettingValue("Comfy.UseNewMenu", false);
+			if (setting && setting != "Disabled") {
 				const t = /** @type { {imageIndex?: number, overIndex?: number, imgs: string[]} } */ /** @type {any} */ (this);
 				let img;
 				if (t.imageIndex != null) {
@@ -41,7 +41,7 @@ body {
 	background-color: var(--bg-color);
 	color: var(--fg-color);
 	grid-template-columns: auto 1fr auto;
-	grid-template-rows: auto auto 1fr auto;
+	grid-template-rows: auto 1fr auto;
 	min-height: -webkit-fill-available;
 	max-height: -webkit-fill-available;
 	min-width: -webkit-fill-available;
@@ -49,32 +49,37 @@ body {
 }
 
 .comfyui-body-top {
-	order: 0;
+	order: -5;
 	grid-column: 1/-1;
 	z-index: 10;
+	display: flex;
+	flex-direction: column;
 }
 
 .comfyui-body-left {
-	order: 1;
+	order: -4;
 	z-index: 10;
+	display: flex;
 }
 
 #graph-canvas {
 	width: 100%;
 	height: 100%;
-	order: 2;
-	grid-column: 1/-1;
+	order: -3;
 }
 
 .comfyui-body-right {
-	order: 3;
+	order: -2;
 	z-index: 10;
+	display: flex;
 }
 
 .comfyui-body-bottom {
-	order: 4;
+	order: -1;
 	grid-column: 1/-1;
 	z-index: 10;
+	display: flex;
+	flex-direction: column;
 }
 
 .comfy-multiline-input {
@@ -408,8 +413,12 @@ dialog::backdrop {
 	background: rgba(0, 0, 0, 0.5);
 }
 
-.comfy-dialog.comfyui-dialog {
+.comfy-dialog.comfyui-dialog.comfy-modal {
 	top: 0;
+	left: 0;
+	right: 0;
+	bottom: 0;
+	transform: none;
 }
 
 .comfy-dialog.comfy-modal {
@@ -20,7 +20,7 @@ class EmptyLatentAudio:
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"

-    CATEGORY = "_for_testing/audio"
+    CATEGORY = "latent/audio"

     def generate(self, seconds):
         batch_size = 1

@@ -35,7 +35,7 @@ class VAEEncodeAudio:
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "encode"

-    CATEGORY = "_for_testing/audio"
+    CATEGORY = "latent/audio"

     def encode(self, vae, audio):
         sample_rate = audio["sample_rate"]

@@ -55,7 +55,7 @@ class VAEDecodeAudio:
     RETURN_TYPES = ("AUDIO",)
     FUNCTION = "decode"

-    CATEGORY = "_for_testing/audio"
+    CATEGORY = "latent/audio"

     def decode(self, vae, samples):
         audio = vae.decode(samples["samples"]).movedim(-1, 1)

@@ -134,7 +134,7 @@ class SaveAudio:

     OUTPUT_NODE = True

-    CATEGORY = "_for_testing/audio"
+    CATEGORY = "audio"

     def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
         import torchaudio  # pylint: disable=import-error

@@ -199,7 +199,7 @@ class LoadAudio:
         ]
         return {"required": {"audio": (sorted(files), {"audio_upload": True})}}

-    CATEGORY = "_for_testing/audio"
+    CATEGORY = "audio"

     RETURN_TYPES = ("AUDIO", )
     FUNCTION = "load"

@@ -209,7 +209,6 @@ class LoadAudio:

         audio_path = folder_paths.get_annotated_filepath(audio)
         waveform, sample_rate = torchaudio.load(audio_path)
-        multiplier = 1.0
         audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
         return (audio, )
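Besides moving these nodes out of "_for_testing", the hunks above show the AUDIO convention the audio nodes share: a dict holding a batched "waveform" tensor and an integer "sample_rate". Below is a minimal sketch of a custom node built on the same convention; the GainAudio name and its gain input are illustrative and not part of this change.

class GainAudio:
    # Illustrative example node following the AUDIO dict convention used above.
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"audio": ("AUDIO",),
                             "gain": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})}}

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "apply"
    CATEGORY = "audio"  # the top-level category the save/load nodes were moved into

    def apply(self, audio, gain):
        # "waveform" is a [batch, channels, samples] tensor; "sample_rate" is an int.
        scaled = audio["waveform"] * gain
        return ({"waveform": scaled, "sample_rate": audio["sample_rate"]}, )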
@@ -147,7 +147,7 @@ class ModelSamplingSD3:

     CATEGORY = "advanced/model"

-    def patch(self, model, shift):
+    def patch(self, model, shift, multiplier=1000):
         m = model.clone()

         sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow

@@ -157,10 +157,22 @@ class ModelSamplingSD3:
             pass

         model_sampling = ModelSamplingAdvanced(model.model.model_config)
-        model_sampling.set_parameters(shift=shift)
+        model_sampling.set_parameters(shift=shift, multiplier=multiplier)
         m.add_object_patch("model_sampling", model_sampling)
         return (m, )

+class ModelSamplingAuraFlow(ModelSamplingSD3):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "shift": ("FLOAT", {"default": 1.73, "min": 0.0, "max": 100.0, "step":0.01}),
+                              }}
+
+    FUNCTION = "patch_aura"
+
+    def patch_aura(self, model, shift):
+        return self.patch(model, shift, multiplier=1.0)
+
 class ModelSamplingContinuousEDM:
     @classmethod
     def INPUT_TYPES(s):

@@ -276,5 +288,6 @@ NODE_CLASS_MAPPINGS = {
     "ModelSamplingContinuousV": ModelSamplingContinuousV,
     "ModelSamplingStableCascade": ModelSamplingStableCascade,
     "ModelSamplingSD3": ModelSamplingSD3,
+    "ModelSamplingAuraFlow": ModelSamplingAuraFlow,
     "RescaleCFG": RescaleCFG,
 }
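ModelSamplingAuraFlow reuses the SD3 patch and only changes the timestep multiplier (1.0 instead of the SD3 default of 1000). The snippet below is a rough sketch, not the comfy.model_sampling implementation, of what a shift and multiplier like these usually control in a discrete-flow schedule; the function names are illustrative.

def shifted_sigma(t: float, shift: float) -> float:
    # Illustrative resolution-shift mapping: shift > 1 spends more of the
    # schedule at high noise levels (the AuraFlow node defaults to 1.73).
    return shift * t / (1.0 + (shift - 1.0) * t)

def timestep(sigma: float, multiplier: float) -> float:
    # Illustrative sigma-to-timestep conversion: multiplier rescales the 0..1
    # flow time to the range the model expects, e.g. 1000 for SD3-style
    # conditioning and 1.0 for AuraFlow as patched above.
    return sigma * multiplier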
@@ -5,6 +5,8 @@ torchsde>=0.2.6
 einops>=0.6.0
 open-clip-torch>=2.24.0
 transformers>=4.29.1
+tokenizers>=0.13.3
+sentencepiece
 peft
 torchinfo
 safetensors>=0.4.2
1 setup.py

@@ -192,6 +192,7 @@ package_data = [
     't5_tokenizer/*',
     '**/*.json',
     '**/*.yaml',
+    '**/*.model'
 ]
 if not is_editable:
     package_data.append('comfy/web/**/*')
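The new '**/*.model' pattern makes tokenizer model files part of the packaged data. A minimal sketch of how recursive package_data globs like these are typically wired up in a setup.py is below; the project name and package key are illustrative, not this repository's actual build configuration, and recursive ** patterns need a reasonably recent setuptools.

from setuptools import setup, find_packages

setup(
    name="example-package",          # illustrative name, not this repo's metadata
    packages=find_packages(),
    package_data={
        "example_package": [         # map a package to the non-Python files it should ship
            "t5_tokenizer/*",
            "**/*.json",
            "**/*.yaml",
            "**/*.model",            # the pattern added above, e.g. sentencepiece tokenizer models
        ],
    },
)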
@@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor

 import jwt
 import pytest
-from aiohttp import ClientSession, ClientConnectorError
+from aiohttp import ClientSession
 from testcontainers.rabbitmq import RabbitMqContainer

 from comfy.client.aio_client import AsyncRemoteComfyClient

@@ -132,13 +132,11 @@ async def test_basic_queue_worker_with_health_check():
     health_check_port = 9090

     async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
-        # Test health check
         health_check_url = f"http://localhost:{health_check_port}/health"

         health_check_ok = await check_health(health_check_url)
         assert health_check_ok, "Health check server did not start properly"

-        # Test the actual worker functionality
         from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
         distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
         await distributed_queue.init()

@@ -153,53 +151,5 @@ async def test_basic_queue_worker_with_health_check():

         await distributed_queue.close()

-    # Test that the health check server is stopped after the worker is closed
     health_check_stopped = not await check_health(health_check_url, max_retries=1)
     assert health_check_stopped, "Health check server did not stop properly"
-
-
-@pytest.mark.asyncio
-async def test_health_check_port_conflict():
-    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
-        params = rabbitmq.get_connection_params()
-        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
-        health_check_port = 9090
-
-        # Start a simple server to occupy the health check port
-        from aiohttp import web
-        async def dummy_handler(request):
-            return web.Response(text="Dummy")
-
-        app = web.Application()
-        app.router.add_get('/', dummy_handler)
-        runner = web.AppRunner(app)
-        await runner.setup()
-        site = web.TCPSite(runner, '0.0.0.0', health_check_port)
-        await site.start()
-
-        try:
-            # Now try to start the DistributedPromptWorker
-            async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
-                # The health check should be disabled, but the worker should still function
-                from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
-                distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
-                await distributed_queue.init()
-
-                queue_item = create_test_prompt()
-                res = await distributed_queue.put_async(queue_item)
-
-                assert res.item_id == queue_item.prompt_id
-                assert len(res.outputs) == 1
-                assert res.status is not None
-                assert res.status.status_str == "success"
-
-                await distributed_queue.close()
-
-            # The original server should still be running
-            async with ClientSession() as session:
-                async with session.get(f"http://localhost:{health_check_port}") as response:
-                    assert response.status == 200
-                    assert await response.text() == "Dummy"
-
-        finally:
-            await runner.cleanup()
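The surviving test polls the worker's health endpoint through a small check_health helper. Its implementation is not part of this diff; the sketch below shows one way such a helper is commonly written with aiohttp, with the retry count and delay chosen arbitrarily.

import asyncio
from aiohttp import ClientSession, ClientConnectorError

async def check_health(url: str, max_retries: int = 30, delay: float = 1.0) -> bool:
    # Illustrative only: poll `url` until it answers 200 OK or the retries run out.
    for _ in range(max_retries):
        try:
            async with ClientSession() as session:
                async with session.get(url) as response:
                    if response.status == 200:
                        return True
        except ClientConnectorError:
            pass
        await asyncio.sleep(delay)
    return False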
@@ -1,7 +1,6 @@
 import pytest

 from comfy.api.components.schema.prompt import Prompt
-from comfy.cli_args_types import Configuration
 from comfy.client.embedded_comfy_client import EmbeddedComfyClient
 from comfy.model_downloader import add_known_models, KNOWN_LORAS
 from comfy.model_downloader_types import CivitFile

@@ -139,9 +138,7 @@ _workflows = {
 @pytest.fixture(scope="module", autouse=False)
 @pytest.mark.asyncio
 async def client(tmp_path_factory) -> EmbeddedComfyClient:
-    config = Configuration()
-    config.cwd = str(tmp_path_factory.mktemp("comfy_test_cwd"))
-    async with EmbeddedComfyClient(config) as client:
+    async with EmbeddedComfyClient() as client:
         yield client
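The simplified fixture relies on EmbeddedComfyClient's defaults. If an individual test still needs an isolated working directory, the pattern removed above can be applied locally, as in this sketch; the fixture name and scope are illustrative and not part of this commit.

import pytest
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import EmbeddedComfyClient

@pytest.fixture(scope="function", autouse=False)
@pytest.mark.asyncio
async def isolated_client(tmp_path_factory) -> EmbeddedComfyClient:
    # Illustrative variant of the removed code: point the client at a throwaway cwd.
    config = Configuration()
    config.cwd = str(tmp_path_factory.mktemp("comfy_test_cwd"))
    async with EmbeddedComfyClient(config) as client:
        yield client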