diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 87bef3681..39d1992d7 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -1,45 +1,48 @@ name: Bug Report description: "Something is broken inside of ComfyUI. (Do not use this if you're just having issues and need help, or if the issue relates to a custom node)" -labels: [ "Potential Bug" ] +labels: ["Potential Bug"] body: - - type: markdown - attributes: - value: | - Before submitting a **Bug Report**, please ensure the following: + - type: markdown + attributes: + value: | + Before submitting a **Bug Report**, please ensure the following: - **1:** You are running the latest version of ComfyUI. - **2:** You have looked at the existing bug reports and made sure this isn't already reported. - **3:** This is an actual bug in ComfyUI, not just a support question and not caused by an custom node. A bug is when you can specify exact steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen. + - **1:** You are running the latest version of ComfyUI. + - **2:** You have looked at the existing bug reports and made sure this isn't already reported. + - **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing + `--disable-all-custom-nodes` command line argument. + - **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact + steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen. - If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. - - type: textarea - attributes: - label: Expected Behavior - description: "What you expected to happen." - validations: - required: true - - type: textarea - attributes: - label: Actual Behavior - description: "What actually happened. Please include a screenshot of the issue if possible." - validations: - required: true - - type: textarea - attributes: - label: Steps to Reproduce - description: "Describe how to reproduce the issue. Please be sure to attach a workflow JSON or PNG, ideally one that doesn't require custom nodes to test. If the bug open happens when certain custom nodes are used, most likely that custom node is what has the bug rather than ComfyUI, in which case it should be reported to the node's author." - validations: - required: true - - type: textarea - attributes: - label: Debug Logs - description: "Please copy the output from your terminal logs here." - render: powershell - validations: - required: true - - type: textarea - attributes: - label: Other - description: "Any other additional information you think might be helpful." - validations: - required: false + If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + - type: textarea + attributes: + label: Expected Behavior + description: "What you expected to happen." + validations: + required: true + - type: textarea + attributes: + label: Actual Behavior + description: "What actually happened. Please include a screenshot of the issue if possible." + validations: + required: true + - type: textarea + attributes: + label: Steps to Reproduce + description: "Describe how to reproduce the issue. Please be sure to attach a workflow JSON or PNG, ideally one that doesn't require custom nodes to test. If the bug open happens when certain custom nodes are used, most likely that custom node is what has the bug rather than ComfyUI, in which case it should be reported to the node's author." + validations: + required: true + - type: textarea + attributes: + label: Debug Logs + description: "Please copy the output from your terminal logs here." + render: powershell + validations: + required: true + - type: textarea + attributes: + label: Other + description: "Any other additional information you think might be helpful." + validations: + required: false diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml new file mode 100644 index 000000000..1fd76b530 --- /dev/null +++ b/.github/workflows/stable-release.yml @@ -0,0 +1,109 @@ + +name: "Release Stable Version" + +on: + push: + tags: + - 'v*' + +jobs: + package_comfy_windows: + permissions: + contents: "write" + packages: "write" + pull-requests: "read" + runs-on: windows-latest + strategy: + matrix: + python_version: [3.11.8] + cuda_version: [121] + steps: + - name: Calculate Minor Version + shell: bash + run: | + # Extract the minor version from the Python version + MINOR_VERSION=$(echo "${{ matrix.python_version }}" | cut -d'.' -f2) + echo "MINOR_VERSION=$MINOR_VERSION" >> $GITHUB_ENV + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python_version }} + + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + persist-credentials: false + - shell: bash + run: | + echo "@echo off + call update_comfyui.bat nopause + echo - + echo This will try to update pytorch and all python dependencies. + echo - + echo If you just want to update normally, close this and run update_comfyui.bat instead. + echo - + pause + ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r ../ComfyUI/requirements.txt pygit2 + pause" > update_comfyui_and_python_dependencies.bat + + python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r requirements.txt pygit2 -w ./temp_wheel_dir + python -m pip install --no-cache-dir ./temp_wheel_dir/* + echo installed basic + ls -lah temp_wheel_dir + mv temp_wheel_dir cu${{ matrix.cuda_version }}_python_deps + mv cu${{ matrix.cuda_version }}_python_deps ../ + mv update_comfyui_and_python_dependencies.bat ../ + cd .. + pwd + ls + + cp -r ComfyUI ComfyUI_copy + curl https://www.python.org/ftp/python/${{ matrix.python_version }}/python-${{ matrix.python_version }}-embed-amd64.zip -o python_embeded.zip + unzip python_embeded.zip -d python_embeded + cd python_embeded + echo ${{ env.MINOR_VERSION }} + echo 'import site' >> ./python3${{ env.MINOR_VERSION }}._pth + curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py + ./python.exe get-pip.py + ./python.exe --version + echo "Pip version:" + ./python.exe -m pip --version + + set PATH=$PWD/Scripts:$PATH + echo $PATH + ./python.exe -s -m pip install ../cu${{ matrix.cuda_version }}_python_deps/* + sed -i '1i../ComfyUI' ./python3${{ env.MINOR_VERSION }}._pth + cd .. + + git clone https://github.com/comfyanonymous/taesd + cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/ + + mkdir ComfyUI_windows_portable + mv python_embeded ComfyUI_windows_portable + mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI + + cd ComfyUI_windows_portable + + mkdir update + cp -r ComfyUI/.ci/update_windows/* ./update/ + cp -r ComfyUI/.ci/windows_base_files/* ./ + cp ../update_comfyui_and_python_dependencies.bat ./update/ + + cd .. + + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z + + cd ComfyUI_windows_portable + python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu + + ls + + - name: Upload binaries to release + uses: svenstaro/upload-release-action@v2 + with: + repo_token: ${{ secrets.GITHUB_TOKEN }} + file: ComfyUI_windows_portable_nvidia.7z + tag: ${{ github.ref }} + overwrite: true + diff --git a/.github/workflows/test-browser.yml b/.github/workflows/test-browser.yml index 0b70f91e0..f8e2ac6cc 100644 --- a/.github/workflows/test-browser.yml +++ b/.github/workflows/test-browser.yml @@ -41,7 +41,7 @@ jobs: working-directory: ComfyUI - name: Start ComfyUI server run: | - python main.py --cpu & + python main.py --cpu 2>&1 | tee console_output.log & wait-for-it --service 127.0.0.1:8188 -t 600 working-directory: ComfyUI - name: Install ComfyUI_frontend dependencies @@ -54,9 +54,22 @@ jobs: - name: Run Playwright tests run: npx playwright test working-directory: ComfyUI_frontend + - name: Check for unhandled exceptions in server log + run: | + if grep -qE "Exception|Error" console_output.log; then + echo "Unhandled exception/error found in server log." + exit 1 + fi + working-directory: ComfyUI - uses: actions/upload-artifact@v4 if: always() with: name: playwright-report path: ComfyUI_frontend/playwright-report/ retention-days: 30 + - uses: actions/upload-artifact@v4 + if: always() + with: + name: console-output + path: ComfyUI/console_output.log + retention-days: 30 diff --git a/README.md b/README.md index 52bbcc62a..6c25f7f28 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/) - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/) +- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Starts up very fast. - Works fully offline: will never download anything. diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 51e400e95..064dfce06 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -10,10 +10,51 @@ from ..ldm.modules.diffusionmodules.util import ( timestep_embedding, ) -from ..ldm.modules.attention import SpatialTransformer +from ..ldm.modules.attention import SpatialTransformer, optimized_attention from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample from ..ldm.util import exists from .. import ops +from collections import OrderedDict + + +class OptimizedAttention(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.heads = nhead + self.c = c + + self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device) + self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + def forward(self, x): + x = self.in_proj(x) + q, k, v = x.split(self.c, dim=2) + out = optimized_attention(q, k, v, self.heads) + return self.out_proj(out) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResBlockUnionControlnet(nn.Module): + def __init__(self, dim, nhead, dtype=None, device=None, operations=None): + super().__init__() + self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations) + self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device) + self.mlp = nn.Sequential( + OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()), + ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))])) + self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device) + + def attention(self, x: torch.Tensor): + return self.attn(x) + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x class ControlledUnetModel(UNetModel): #implemented in the ldm unet @@ -53,6 +94,7 @@ class ControlNet(nn.Module): transformer_depth_middle=None, transformer_depth_output=None, attn_precision=None, + union_controlnet_num_control_type=None, device=None, operations=ops.disable_weight_init, **kwargs, @@ -280,6 +322,65 @@ class ControlNet(nn.Module): self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device) self._feature_size += ch + if union_controlnet_num_control_type is not None: + self.num_control_type = union_controlnet_num_control_type + num_trans_channel = 320 + num_trans_head = 8 + num_trans_layer = 1 + num_proj_channel = 320 + # task_scale_factor = num_trans_channel ** 0.5 + self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device)) + + self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)]) + self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device) + #----------------------------------------------------------------------------------------------------- + + control_add_embed_dim = 256 + class ControlAddEmbedding(nn.Module): + def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None): + super().__init__() + self.num_control_type = num_control_type + self.in_dim = in_dim + self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device) + self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device) + def forward(self, control_type, dtype, device): + c_type = torch.zeros((self.num_control_type,), device=device) + c_type[control_type] = 1.0 + c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim)) + return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type))) + + self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations) + else: + self.task_embedding = None + self.control_add_embedding = None + + def union_controlnet_merge(self, hint, control_type, emb, context): + # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main + inputs = [] + condition_list = [] + + for idx in range(min(1, len(control_type))): + controlnet_cond = self.input_hint_block(hint[idx], emb, context) + feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) + if idx < len(control_type): + feat_seq += self.task_embedding[control_type[idx]] + + inputs.append(feat_seq.unsqueeze(1)) + condition_list.append(controlnet_cond) + + x = torch.cat(inputs, dim=1) + x = self.transformer_layes(x) + controlnet_cond_fuser = None + for idx in range(len(control_type)): + alpha = self.spatial_ch_projs(x[:, idx]) + alpha = alpha.unsqueeze(-1).unsqueeze(-1) + o = condition_list[idx] + alpha + if controlnet_cond_fuser is None: + controlnet_cond_fuser = o + else: + controlnet_cond_fuser += o + return controlnet_cond_fuser + def make_zero_conv(self, channels, operations=None, dtype=None, device=None): return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)) @@ -287,7 +388,18 @@ class ControlNet(nn.Module): t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) - guided_hint = self.input_hint_block(hint, emb, context) + guided_hint = None + if self.control_add_embedding is not None: #Union Controlnet + control_type = kwargs.get("control_type", []) + + emb += self.control_add_embedding(control_type, emb.dtype, emb.device) + if len(control_type) > 0: + if len(hint.shape) < 5: + hint = hint.unsqueeze(dim=0) + guided_hint = self.union_controlnet_merge(hint, control_type, emb, context) + + if guided_hint is None: + guided_hint = self.input_hint_block(hint, emb, context) out_output = [] out_middle = [] diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index a236306c5..69e3cbb82 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,3 +1,4 @@ +from .component_model import files from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace import os import torch @@ -30,9 +31,17 @@ def clip_preprocess(image, size=224): return (image - mean.view([3,1,1])) / std.view([3,1,1]) class ClipVisionModel(): - def __init__(self, json_config): - with open(json_config) as f: - config = json.load(f) + def __init__(self, json_config: dict | str): + if isinstance(json_config, dict): + config = json_config + elif json_config is not None and isinstance(json_config, str): + if json_config.startswith("{"): + config = json.loads(json_config) + else: + with open(json_config) as f: + config = json.load(f) + else: + raise ValueError(f"json_config had invalid value={json_config}") self.load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() @@ -88,12 +97,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): if convert_keys: sd = convert_to_transformers(sd, prefix) if "vision_model.encoder.layers.47.layer_norm1.weight" in sd: - # todo: fix the importlib issue here - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json") + json_config = files.get_path_as_dict(None, "clip_vision_config_g.json") elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") + json_config = files.get_path_as_dict(None, "clip_vision_config_h.json") elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") + json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json") else: return None diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 07bcda3f7..b4585206e 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -85,7 +85,8 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): def cleanup_temp(): try: - temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") + folder_paths.get_temp_directory() + temp_dir = folder_paths.get_temp_directory() if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) except NameError: @@ -115,7 +116,7 @@ async def main(): # configure extra model paths earlier try: - extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") + extra_model_paths_config_path = os.path.join(os.getcwd(), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): load_extra_path_config(extra_model_paths_config_path) except NameError: diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 3e13c2348..412036fdc 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -439,6 +439,7 @@ class PromptServer(ExecutorToClientProgress): info['name'] = node_class info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else '' + info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes") info['category'] = 'sd' if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: info['output_node'] = True @@ -845,18 +846,9 @@ class PromptServer(ExecutorToClientProgress): return json_data - @classmethod - def get_output_path(cls, subfolder: str | None = None, filename: str | None = None): - paths = [path for path in ["output", subfolder, filename] if path is not None and path != ""] - return os.path.join(os.path.dirname(os.path.realpath(__file__)), *paths) - @classmethod def get_upload_dir(cls) -> str: - upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../input") - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) - return upload_dir + return folder_paths.get_input_directory() @classmethod def get_too_busy_queue_size(cls): diff --git a/comfy/component_model/files.py b/comfy/component_model/files.py index c4c9dfff8..636ecfee2 100644 --- a/comfy/component_model/files.py +++ b/comfy/component_model/files.py @@ -19,7 +19,7 @@ def get_path_as_dict(config_dict_or_path: str | dict | None, config_path_inside_ config: dict | None = None if config_dict_or_path is None: - config_dict_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config_path_inside_package) + config_dict_or_path = config_path_inside_package if isinstance(config_dict_or_path, str): if config_dict_or_path.startswith("{"): diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 291767ff6..8ae92e108 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -412,6 +412,12 @@ def load_controlnet(ckpt_path, model=None): if k in controlnet_data: new_sd[diffusers_keys[k]] = controlnet_data.pop(k) + if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet + controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0] + for k in list(controlnet_data.keys()): + new_k = k.replace('.attn.in_proj_', '.attn.in_proj.') + new_sd[new_k] = controlnet_data.pop(k) + leftover_keys = controlnet_data.keys() if len(leftover_keys) > 0: logging.warning("leftover keys: {}".format(leftover_keys)) diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py new file mode 100644 index 000000000..c465619bd --- /dev/null +++ b/comfy/ldm/aura/mmdit.py @@ -0,0 +1,479 @@ +#AuraFlow MMDiT +#Originally written by the AuraFlow Authors + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None: + super().__init__() + if hidden_dim is None: + hidden_dim = 4 * dim + + n_hidden = int(2 * hidden_dim / 3) + n_hidden = find_multiple(n_hidden, 256) + + self.c_fc1 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device) + self.c_fc2 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device) + self.c_proj = operations.Linear(n_hidden, dim, bias=False, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.c_fc1(x)) * self.c_fc2(x) + x = self.c_proj(x) + return x + + +class MultiHeadLayerNorm(nn.Module): + def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None): + # Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78 + + super().__init__() + self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) * torch.rsqrt( + variance + self.variance_epsilon + ) + hidden_states = self.weight.to(torch.float32) * hidden_states + return hidden_states.to(input_dtype) + +class SingleAttention(nn.Module): + def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None): + super().__init__() + + self.n_heads = n_heads + self.head_dim = dim // n_heads + + # this is for cond + self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + + self.q_norm1 = ( + MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device) + if mh_qknorm + else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device) + ) + self.k_norm1 = ( + MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device) + if mh_qknorm + else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device) + ) + + #@torch.compile() + def forward(self, c): + + bsz, seqlen1, _ = c.shape + + q, k, v = self.w1q(c), self.w1k(c), self.w1v(c) + q = q.view(bsz, seqlen1, self.n_heads, self.head_dim) + k = k.view(bsz, seqlen1, self.n_heads, self.head_dim) + v = v.view(bsz, seqlen1, self.n_heads, self.head_dim) + q, k = self.q_norm1(q), self.k_norm1(k) + + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + c = self.w1o(output) + return c + + + +class DoubleAttention(nn.Module): + def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None): + super().__init__() + + self.n_heads = n_heads + self.head_dim = dim // n_heads + + # this is for cond + self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + + # this is for x + self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + + self.q_norm1 = ( + MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device) + if mh_qknorm + else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device) + ) + self.k_norm1 = ( + MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device) + if mh_qknorm + else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device) + ) + + self.q_norm2 = ( + MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device) + if mh_qknorm + else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device) + ) + self.k_norm2 = ( + MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device) + if mh_qknorm + else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device) + ) + + + #@torch.compile() + def forward(self, c, x): + + bsz, seqlen1, _ = c.shape + bsz, seqlen2, _ = x.shape + seqlen = seqlen1 + seqlen2 + + cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c) + cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim) + ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim) + cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim) + cq, ck = self.q_norm1(cq), self.k_norm1(ck) + + xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x) + xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim) + xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim) + xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim) + xq, xk = self.q_norm2(xq), self.k_norm2(xk) + + # concat all + q, k, v = ( + torch.cat([cq, xq], dim=1), + torch.cat([ck, xk], dim=1), + torch.cat([cv, xv], dim=1), + ) + + output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) + + c, x = output.split([seqlen1, seqlen2], dim=1) + c = self.w1o(c) + x = self.w2o(x) + + return c, x + + +class MMDiTBlock(nn.Module): + def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None): + super().__init__() + + self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device) + self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device) + if not is_last: + self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations) + self.modC = nn.Sequential( + nn.SiLU(), + operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device), + ) + else: + self.modC = nn.Sequential( + nn.SiLU(), + operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device), + ) + + self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device) + self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device) + self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations) + self.modX = nn.Sequential( + nn.SiLU(), + operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device), + ) + + self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations) + self.is_last = is_last + + #@torch.compile() + def forward(self, c, x, global_cond, **kwargs): + + cres, xres = c, x + + cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = ( + self.modC(global_cond).chunk(6, dim=1) + ) + + c = modulate(self.normC1(c), cshift_msa, cscale_msa) + + # xpath + xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = ( + self.modX(global_cond).chunk(6, dim=1) + ) + + x = modulate(self.normX1(x), xshift_msa, xscale_msa) + + # attention + c, x = self.attn(c, x) + + + c = self.normC2(cres + cgate_msa.unsqueeze(1) * c) + c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp)) + c = cres + c + + x = self.normX2(xres + xgate_msa.unsqueeze(1) * x) + x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp)) + x = xres + x + + return c, x + +class DiTBlock(nn.Module): + # like MMDiTBlock, but it only has X + def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None): + super().__init__() + + self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device) + + self.modCX = nn.Sequential( + nn.SiLU(), + operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device), + ) + + self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations) + self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations) + + #@torch.compile() + def forward(self, cx, global_cond, **kwargs): + cxres = cx + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX( + global_cond + ).chunk(6, dim=1) + cx = modulate(self.norm1(cx), shift_msa, scale_msa) + cx = self.attn(cx) + cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx) + mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp)) + cx = gate_mlp.unsqueeze(1) * mlpout + + cx = cxres + cx + + return cx + + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + super().__init__() + self.mlp = nn.Sequential( + operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = 1000 * torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half) / half + ).to(t.device) + args = t[:, None] * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + #@torch.compile() + def forward(self, t, dtype): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class MMDiT(nn.Module): + def __init__( + self, + in_channels=4, + out_channels=4, + patch_size=2, + dim=3072, + n_layers=36, + n_double_layers=4, + n_heads=12, + global_conddim=3072, + cond_seq_dim=2048, + max_seq=32 * 32, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + + self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations) + + self.cond_seq_linear = operations.Linear( + cond_seq_dim, dim, bias=False, dtype=dtype, device=device + ) # linear for something like text sequence. + self.init_x_linear = operations.Linear( + patch_size * patch_size * in_channels, dim, dtype=dtype, device=device + ) # init linear for patchified image. + + self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device)) + self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device)) + + self.double_layers = nn.ModuleList([]) + self.single_layers = nn.ModuleList([]) + + + for idx in range(n_double_layers): + self.double_layers.append( + MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations) + ) + + for idx in range(n_double_layers, n_layers): + self.single_layers.append( + DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations) + ) + + + self.final_linear = operations.Linear( + dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device + ) + + self.modF = nn.Sequential( + nn.SiLU(), + operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device), + ) + + self.out_channels = out_channels + self.patch_size = patch_size + self.n_double_layers = n_double_layers + self.n_layers = n_layers + + self.h_max = round(max_seq**0.5) + self.w_max = round(max_seq**0.5) + + @torch.no_grad() + def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)): + # extend pe + pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]] + + pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1) + + # now we need to extend this to target_dim. for this we will use interpolation. + # we will use torch.nn.functional.interpolate + pe_as_2d = F.interpolate( + pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear" + ) + pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1) + self.positional_encoding.data = pe_new.unsqueeze(0).contiguous() + self.h_max, self.w_max = target_dim + print("PE extended to", target_dim) + + def pe_selection_index_based_on_dim(self, h, w): + h_p, w_p = h // self.patch_size, w // self.patch_size + original_pe_indexes = torch.arange(self.positional_encoding.shape[1]) + original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max) + starth = self.h_max // 2 - h_p // 2 + endh =starth + h_p + startw = self.w_max // 2 - w_p // 2 + endw = startw + w_p + original_pe_indexes = original_pe_indexes[ + starth:endh, startw:endw + ] + return original_pe_indexes.flatten() + + def unpatchify(self, x, h, w): + c = self.out_channels + p = self.patch_size + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + + def patchify(self, x): + B, C, H, W = x.size() + pad_h = (self.patch_size - H % self.patch_size) % self.patch_size + pad_w = (self.patch_size - W % self.patch_size) % self.patch_size + + x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect') + x = x.view( + B, + C, + (H + 1) // self.patch_size, + self.patch_size, + (W + 1) // self.patch_size, + self.patch_size, + ) + x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + return x + + def apply_pos_embeds(self, x, h, w): + h = (h + 1) // self.patch_size + w = (w + 1) // self.patch_size + max_dim = max(h, w) + + cur_dim = self.h_max + pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype) + + if max_dim > cur_dim: + pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1) + cur_dim = max_dim + + from_h = (cur_dim - h) // 2 + from_w = (cur_dim - w) // 2 + pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w] + return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1]) + + def forward(self, x, timestep, context, **kwargs): + # patchify x, add PE + b, c, h, w = x.shape + + # pe_indexes = self.pe_selection_index_based_on_dim(h, w) + # print(pe_indexes, pe_indexes.shape) + + x = self.init_x_linear(self.patchify(x)) # B, T_x, D + x = self.apply_pos_embeds(x, h, w) + # x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype) + # x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype) + + # process conditions for MMDiT Blocks + c_seq = context # B, T_c, D_c + t = timestep + + c = self.cond_seq_linear(c_seq) # B, T_c, D + c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1) + + global_cond = self.t_embedder(t, x.dtype) # B, D + + if len(self.double_layers) > 0: + for layer in self.double_layers: + c, x = layer(c, x, global_cond, **kwargs) + + if len(self.single_layers) > 0: + c_len = c.size(1) + cx = torch.cat([c, x], dim=1) + for layer in self.single_layers: + cx = layer(cx, global_cond, **kwargs) + + x = cx[:, c_len:] + + fshift, fscale = self.modF(global_cond).chunk(2, dim=1) + + x = modulate(x, fshift, fscale) + x = self.final_linear(x) + x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w] + return x diff --git a/comfy/lora.py b/comfy/lora.py index 584b4fff2..b70debce5 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -272,4 +272,12 @@ def model_lora_keys_unet(model, key_map={}): key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora key_map[key_lora] = to + if isinstance(model, model_base.AuraFlow): #Diffusers lora AuraFlow + diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format + key_map[key_lora] = to + return key_map diff --git a/comfy/model_base.py b/comfy/model_base.py index 3c8c13c7e..e68d6b5e0 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -17,7 +17,7 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation - +from .ldm.aura.mmdit import MMDiT as AuraMMDiT class ModelType(Enum): EPS = 1 @@ -622,6 +622,17 @@ class SD3(BaseModel): area = input_shape[0] * input_shape[2] * input_shape[3] return (area * 0.3) * (1024 * 1024) +class AuraFlow(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=AuraMMDiT) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = conds.CONDRegular(cross_attn) + return out + class StableAudio1(BaseModel): def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 044245a6f..32eaa5ac3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -104,6 +104,19 @@ def detect_unet_config(state_dict, key_prefix): unet_config["audio_model"] = "dit1.0" return unet_config + if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit + unet_config = {} + unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1] + unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1] + double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.') + single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.') + unet_config["n_double_layers"] = double_layers + unet_config["n_layers"] = double_layers + single_layers + return unet_config + + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: + return None + unet_config = { "use_checkpoint": False, "image_size": 32, @@ -238,6 +251,8 @@ def model_config_from_unet_config(unet_config, state_dict=None): def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): unet_config = detect_unet_config(state_dict, unet_key_prefix) + if unet_config is None: + return None model_config = model_config_from_unet_config(unet_config, state_dict) if model_config is None and use_base_if_no_match: return supported_models_base.BASE(unet_config) @@ -247,6 +262,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal def unet_prefix_from_state_dict(state_dict): if "model.model.postprocess_conv.weight" in state_dict: #audio models unet_key_prefix = "model.model." + elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow + unet_key_prefix = "model." else: unet_key_prefix = "model.diffusion_model." return unet_key_prefix @@ -436,38 +453,45 @@ def model_config_from_diffusers_unet(state_dict): return None def convert_diffusers_mmdit(state_dict, output_prefix=""): - out_sd = None - num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') - if num_blocks > 0: + out_sd = {} + + if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3 + num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 - out_sd = {} sd_map = utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) - for k in sd_map: - weight = state_dict.get(k, None) - if weight is not None: - t = sd_map[k] + elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow + num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.') + num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.') + sd_map = utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix) + else: + return None - if not isinstance(t, str): - if len(t) > 2: - fun = t[2] - else: - fun = lambda a: a - offset = t[1] - if offset is not None: - old_weight = out_sd.get(t[0], None) - if old_weight is None: - old_weight = torch.empty_like(weight) - old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1)) + for k in sd_map: + weight = state_dict.get(k, None) + if weight is not None: + t = sd_map[k] - w = old_weight.narrow(offset[0], offset[1], offset[2]) - else: - old_weight = weight - w = weight - w[:] = fun(weight) - t = t[0] - out_sd[t] = old_weight + if not isinstance(t, str): + if len(t) > 2: + fun = t[2] else: - out_sd[t] = weight - state_dict.pop(k) + fun = lambda a: a + offset = t[1] + if offset is not None: + old_weight = out_sd.get(t[0], None) + if old_weight is None: + old_weight = torch.empty_like(weight) + old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1)) + + w = old_weight.narrow(offset[0], offset[1], offset[2]) + else: + old_weight = weight + w = weight + w[:] = fun(weight) + t = t[0] + out_sd[t] = old_weight + else: + out_sd[t] = weight + state_dict.pop(k) return out_sd diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b2f8ca3a6..6e9c40a95 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -60,6 +60,12 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl model_options["disable_cfg1_optimization"] = True return model_options +def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False): + model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function] + if disable_cfg1_optimization: + model_options["disable_cfg1_optimization"] = True + return model_options + class ModelPatcher(ModelManageable): def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): self.size = size @@ -142,6 +148,9 @@ class ModelPatcher(ModelManageable): def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization) + def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False): + self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization) + def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction): self.model_options["model_function_wrapper"] = unet_wrapper_function diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 5989b2a81..019b32e12 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -192,11 +192,12 @@ class ModelSamplingDiscreteFlow(torch.nn.Module): else: sampling_settings = {} - self.set_parameters(shift=sampling_settings.get("shift", 1.0)) + self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000)) - def set_parameters(self, shift=1.0, timesteps=1000): + def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000): self.shift = shift - ts = self.sigma(torch.arange(1, timesteps + 1, 1)) + self.multiplier = multiplier + ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier) self.register_buffer('sigmas', ts) @property @@ -208,10 +209,10 @@ class ModelSamplingDiscreteFlow(torch.nn.Module): return self.sigmas[-1] def timestep(self, sigma): - return sigma * 1000 + return sigma * self.multiplier def sigma(self, timestep): - return time_snr_shift(self.shift, timestep / 1000) + return time_snr_shift(self.shift, timestep / self.multiplier) def percent_to_sigma(self, percent): if percent <= 0.0: diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 7a9e33a55..ab0e8ca53 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -46,8 +46,9 @@ class CLIPTextEncode: def encode(self, clip, text): tokens = clip.tokenize(text) - cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) - return ([[cond, {"pooled_output": pooled}]], ) + output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) + cond = output.pop("cond") + return ([[cond, output]], ) class ConditioningCombine: @classmethod @@ -223,8 +224,9 @@ class ConditioningZeroOut: c = [] for t in conditioning: d = t[1].copy() - if "pooled_output" in d: - d["pooled_output"] = torch.zeros_like(d["pooled_output"]) + pooled_output = d.get("pooled_output", None) + if pooled_output is not None: + d["pooled_output"] = torch.zeros_like(pooled_output) n = [torch.zeros_like(t[0]), d] c.append(n) return (c, ) diff --git a/comfy/sa_t5.py b/comfy/sa_t5.py index 37be5287e..4521c364e 100644 --- a/comfy/sa_t5.py +++ b/comfy/sa_t5.py @@ -1,22 +1,27 @@ -from comfy import sd1_clip from transformers import T5TokenizerFast + import comfy.t5 -import os +from comfy import sd1_clip +from comfy.component_model import files + class T5BaseModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_base.json") + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None): + textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_base.json") super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True) + class T5BaseTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None): - tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer") super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) + class SAT5Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None): super().__init__(embedding_directory=embedding_directory, clip_name="t5base", tokenizer=T5BaseTokenizer) + class SAT5Model(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, **kwargs): - super().__init__(device=device, dtype=dtype, clip_name="t5base", clip_model=T5BaseModel, **kwargs) + super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs) diff --git a/comfy/samplers.py b/comfy/samplers.py index 41ca60a64..11fd794e8 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -278,6 +278,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option conds = [cond, uncond_] out = calc_cond_batch(model, conds, x, timestep, model_options) + + for fn in model_options.get("sampler_pre_cfg_function", []): + args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep, + "input": x, "sigma": timestep, "model": model, "model_options": model_options} + out = fn(args) + return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) diff --git a/comfy/sd.py b/comfy/sd.py index c7cde70d8..ef9b73bd5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -28,37 +28,7 @@ from .t2i_adapter import adapter from .taesd import taesd from . import sd3_clip from . import sa_t5 - - -def load_model_weights(model, sd): - m, u = model.load_state_dict(sd, strict=False) - m = set(m) - unexpected_keys = set(u) - - k = list(sd.keys()) - for x in k: - if x not in unexpected_keys: - w = sd.pop(x) - del w - if len(m) > 0: - logging.warning("missing {}".format(m)) - return model - - -def load_clip_weights(model, sd): - k = list(sd.keys()) - for x in k: - if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."): - y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.") - sd[y] = sd.pop(x) - - if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd: - ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] - if ids.dtype == torch.float32: - sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - - sd = utils.clip_text_transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.") - return load_model_weights(model, sd) +from .text_encoders import aura_t5 def load_lora_for_models(model, clip, _lora, strength_model, strength_clip): @@ -136,7 +106,7 @@ class CLIP: def tokenize(self, text, return_word_ids=False): return self.tokenizer.tokenize_with_weights(text, return_word_ids) - def encode_from_tokens(self, tokens, return_pooled=False): + def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False): self.cond_stage_model.reset_clip_options() if self.layer_idx is not None: @@ -146,7 +116,15 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() - cond, pooled = self.cond_stage_model.encode_token_weights(tokens) + o = self.cond_stage_model.encode_token_weights(tokens) + cond, pooled = o[:2] + if return_dict: + out = {"cond": cond, "pooled_output": pooled} + if len(o) > 2: + for k in o[2]: + out[k] = o[2][k] + return out + if return_pooled: return cond, pooled return cond @@ -447,9 +425,14 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]: - dtype_t5 = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"].dtype - clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) - clip_target.tokenizer = sd3_clip.SD3Tokenizer + weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] + dtype_t5 = weight.dtype + if weight.shape[-1] == 4096: + clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5) + clip_target.tokenizer = sd3_clip.SD3Tokenizer + elif weight.shape[-1] == 2048: + clip_target.clip = aura_t5.AuraT5Model + clip_target.tokenizer = aura_t5.AuraT5Tokenizer elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]: clip_target.clip = sa_t5.SAT5Model clip_target.tokenizer = sa_t5.SAT5Tokenizer @@ -529,13 +512,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o load_device = model_management.get_torch_device() model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix) + if model_config is None: + raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - if model_config is None: - raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) - if model_config.clip_vision_prefix is not None: if output_clipvision: clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) @@ -586,42 +569,37 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet_state_dict(sd): # load unet in diffusers or regular format #Allow loading unets from checkpoint files - checkpoint = False diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) temp_sd = utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) if len(temp_sd) > 0: sd = temp_sd - checkpoint = True parameters = utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() + model_config = model_detection.model_config_from_unet(sd, "") - if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: # ldm or stable cascade - model_config = model_detection.model_config_from_unet(sd, "") - if model_config is None: - return None + if model_config is not None: new_sd = sd - elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3 + else: new_sd = model_detection.convert_diffusers_mmdit(sd, "") - if new_sd is None: - return None - model_config = model_detection.model_config_from_unet(new_sd, "") - if model_config is None: - return None - else: # diffusers - model_config = model_detection.model_config_from_diffusers_unet(sd) - if model_config is None: - return None + if new_sd is not None: #diffusers mmdit + model_config = model_detection.model_config_from_unet(new_sd, "") + if model_config is None: + return None + else: # diffusers unet + model_config = model_detection.model_config_from_diffusers_unet(sd) + if model_config is None: + return None - diffusers_keys = utils.unet_to_diffusers(model_config.unet_config) + diffusers_keys = utils.unet_to_diffusers(model_config.unet_config) - new_sd = {} - for k in diffusers_keys: - if k in sd: - new_sd[diffusers_keys[k]] = sd.pop(k) - else: - logging.warning("{} {}".format(diffusers_keys[k], k)) + new_sd = {} + for k in diffusers_keys: + if k in sd: + new_sd[diffusers_keys[k]] = sd.pop(k) + else: + logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ae0ae0f17..447726f4d 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -1,11 +1,13 @@ from __future__ import annotations import copy +import importlib.resources import logging import numbers import os import traceback import zipfile +from importlib.abc import Traversable from typing import Tuple, Sequence, TypeVar import torch @@ -14,6 +16,7 @@ from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMi from . import clip_model from . import model_management from . import ops +from .component_model import files from .component_model.files import get_path_as_dict, get_package_as_path @@ -29,7 +32,58 @@ def gen_empty_tokens(special_tokens, length): output += [pad_token] * (length - len(output)) return output -class SDClipModel(torch.nn.Module): +class ClipTokenWeightEncoder: + def encode_token_weights(self, token_weight_pairs): + to_encode = list() + max_token_len = 0 + has_weights = False + for x in token_weight_pairs: + tokens = list(map(lambda a: a[0], x)) + max_token_len = max(len(tokens), max_token_len) + has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) + to_encode.append(tokens) + + sections = len(to_encode) + if has_weights or sections == 0: + to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) + + o = self.encode(to_encode) + out, pooled = o[:2] + + if pooled is not None: + first_pooled = pooled[0:1].to(model_management.intermediate_device()) + else: + first_pooled = pooled + + output = [] + for k in range(0, sections): + z = out[k:k+1] + if has_weights: + z_empty = out[-1] + for i in range(len(z)): + for j in range(len(z[i])): + weight = token_weight_pairs[k][j][1] + if weight != 1.0: + z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] + output.append(z) + + if (len(output) == 0): + r = (out[-1:].to(model_management.intermediate_device()), first_pooled) + else: + r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled) + + if len(o) > 2: + extra = {} + for k in o[2]: + v = o[2][k] + if k == "attention_mask": + v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()) + extra[k] = v + + r = r + (extra,) + return r + +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" LAYERS = [ "last", @@ -40,7 +94,7 @@ class SDClipModel(torch.nn.Module): def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel, special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, - return_projected_pooled=True): # clip-vit-base-patch32 + return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32 super().__init__() if special_tokens is None: special_tokens = {"start": 49406, "end": 49407, "pad": 49407} @@ -63,6 +117,7 @@ class SDClipModel(torch.nn.Module): self.layer_norm_hidden_state = layer_norm_hidden_state self.return_projected_pooled = return_projected_pooled + self.return_attention_masks = return_attention_masks if layer == "hidden": assert layer_idx is not None @@ -136,7 +191,7 @@ class SDClipModel(torch.nn.Module): tokens = torch.tensor(tokens, dtype=torch.long).to(device) attention_mask = None - if self.enable_attention_masks: + if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks: attention_mask = torch.zeros_like(tokens) end_token = self.special_tokens.get("end", -1) for x in range(attention_mask.shape[0]): @@ -145,7 +200,11 @@ class SDClipModel(torch.nn.Module): if tokens[x, y] == end_token: break - outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + attention_mask_model = None + if self.enable_attention_masks: + attention_mask_model = attention_mask + + outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) self.transformer.set_input_embeddings(backup_embeds) if self.layer == "last": @@ -153,7 +212,7 @@ class SDClipModel(torch.nn.Module): else: z = outputs[1].float() - if self.zero_out_masked and attention_mask is not None: + if self.zero_out_masked: z *= attention_mask.unsqueeze(-1).float() pooled_output = None @@ -163,6 +222,13 @@ class SDClipModel(torch.nn.Module): elif outputs[2] is not None: pooled_output = outputs[2].float() + extra = {} + if self.return_attention_masks: + extra["attention_mask"] = attention_mask + + if len(extra) > 0: + return z, pooled_output, extra + return z, pooled_output def encode(self, tokens): @@ -374,10 +440,13 @@ SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer') class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None): if tokenizer_path is None: - tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") - if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")): + tokenizer_path = files.get_package_as_path("comfy.sd1_tokenizer") + if isinstance(tokenizer_path, Traversable): + contextlib_path = importlib.resources.as_file(tokenizer_path) + tokenizer_path = contextlib_path.__enter__() + if not tokenizer_path.endswith(".model") and not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")): # package based tokenizer_path = get_package_as_path('comfy.sd1_tokenizer') self.tokenizer_class = tokenizer_class @@ -395,6 +464,14 @@ class SDTokenizer: self.tokens_start = 0 self.start_token = None self.end_token = empty[0] + + if pad_token is not None: + self.pad_token = pad_token + elif pad_with_end: + self.pad_token = self.end_token + else: + self.pad_token = 0 + self.pad_with_end = pad_with_end self.pad_to_max_length = pad_to_max_length self.additional_tokens: Tuple[str, ...] = () @@ -439,10 +516,6 @@ class SDTokenizer: Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens. Returned list has the dimensions NxM where M is the input size of CLIP ''' - if self.pad_with_end: - pad_token = self.end_token - else: - pad_token = 0 text = escape_important(text) parsed_weights = token_weights(text, 1.0) @@ -502,7 +575,7 @@ class SDTokenizer: else: batch.append((self.end_token, 1.0, 0)) if self.pad_to_max_length: - batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) + batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length)) # start new batch batch = [] if self.start_token is not None: @@ -515,9 +588,9 @@ class SDTokenizer: # fill last batch batch.append((self.end_token, 1.0, 0)) if self.pad_to_max_length: - batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) + batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch))) if self.min_length is not None and len(batch) < self.min_length: - batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch))) + batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch))) if not return_word_ids: batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens] @@ -560,10 +633,16 @@ class SD1Tokenizer: class SD1ClipModel(torch.nn.Module): - def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs): + def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs): super().__init__() - self.clip_name = clip_name - self.clip = "clip_{}".format(self.clip_name) + + if name is not None: + self.clip_name = name + self.clip = "{}".format(self.clip_name) + else: + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs)) self.dtypes = set() @@ -578,8 +657,8 @@ class SD1ClipModel(torch.nn.Module): def encode_token_weights(self, token_weight_pairs): token_weight_pairs = token_weight_pairs[self.clip_name] - out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs) - return out, pooled + out = getattr(self, self.clip).encode_token_weights(token_weight_pairs) + return out def load_sd(self, sd): return getattr(self, self.clip).load_sd(sd) diff --git a/comfy/sd3_clip.py b/comfy/sd3_clip.py index 0713eb285..5990ec3b9 100644 --- a/comfy/sd3_clip.py +++ b/comfy/sd3_clip.py @@ -1,39 +1,45 @@ +import logging +import os + +import torch +from transformers import T5TokenizerFast + +import comfy.model_management +import comfy.t5 from comfy import sd1_clip from comfy import sdxl_clip -from transformers import T5TokenizerFast -import comfy.t5 -import torch -import os -import comfy.model_management -import logging +from comfy.component_model import files + class T5XXLModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None): + textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json") super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5) + class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None): - tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer") super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) + class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None): super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer) + class SDT5XXLModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, **kwargs): super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs) - class SD3Tokenizer: def __init__(self, embedding_directory=None): self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) - def tokenize_with_weights(self, text:str, return_word_ids=False): + def tokenize_with_weights(self, text: str, return_word_ids=False): out = {} out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) @@ -43,6 +49,7 @@ class SD3Tokenizer: def untokenize(self, token_weight_pair): return self.clip_g.untokenize(token_weight_pair) + class SD3ClipModel(torch.nn.Module): def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None): super().__init__() @@ -143,8 +150,10 @@ class SD3ClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) + def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None): class SD3ClipModel_(SD3ClipModel): def __init__(self, device="cpu", dtype=None): super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype) + return SD3ClipModel_ diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 21fdb7ec7..0c2994026 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -7,6 +7,7 @@ from . import sd2_clip from . import sdxl_clip from . import sd3_clip from . import sa_t5 +from .text_encoders import aura_t5 from . import supported_models_base from . import latent_formats @@ -556,7 +557,29 @@ class StableAudio(supported_models_base.BASE): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model) +class AuraFlow(supported_models_base.BASE): + unet_config = { + "cond_seq_dim": 2048, + } -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio] + sampling_settings = { + "multiplier": 1.0, + "shift": 1.73, + } + + unet_extra_config = {} + latent_format = latent_formats.SDXL + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.AuraFlow(self, device=device) + return out + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget(aura_t5.AuraT5Tokenizer, aura_t5.AuraT5Model) + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow] models += [SVD_img2vid] diff --git a/comfy/t5.py b/comfy/t5.py index 06dfe4766..448c5aad3 100644 --- a/comfy/t5.py +++ b/comfy/t5.py @@ -13,29 +13,36 @@ class T5LayerNorm(torch.nn.Module): x = x * torch.rsqrt(variance + self.variance_epsilon) return self.weight.to(device=x.device, dtype=x.dtype) * x +activations = { + "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"), + "relu": torch.nn.functional.relu, +} + class T5DenseActDense(torch.nn.Module): - def __init__(self, model_dim, ff_dim, dtype, device, operations): + def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations): super().__init__() self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) # self.dropout = nn.Dropout(config.dropout_rate) + self.act = activations[ff_activation] def forward(self, x): - x = torch.nn.functional.relu(self.wi(x)) + x = self.act(self.wi(x)) # x = self.dropout(x) x = self.wo(x) return x class T5DenseGatedActDense(torch.nn.Module): - def __init__(self, model_dim, ff_dim, dtype, device, operations): + def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations): super().__init__() self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) # self.dropout = nn.Dropout(config.dropout_rate) + self.act = activations[ff_activation] def forward(self, x): - hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_gelu = self.act(self.wi_0(x)) hidden_linear = self.wi_1(x) x = hidden_gelu * hidden_linear # x = self.dropout(x) @@ -43,12 +50,12 @@ class T5DenseGatedActDense(torch.nn.Module): return x class T5LayerFF(torch.nn.Module): - def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations): + def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations): super().__init__() - if ff_activation == "gelu_pytorch_tanh": - self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations) - elif ff_activation == "relu": - self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations) + if gated_act: + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation, dtype, device, operations) + else: + self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation, dtype, device, operations) self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) # self.dropout = nn.Dropout(config.dropout_rate) @@ -171,11 +178,11 @@ class T5LayerSelfAttention(torch.nn.Module): return x, past_bias class T5Block(torch.nn.Module): - def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations): + def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations): super().__init__() self.layer = torch.nn.ModuleList() self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations)) - self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations)) + self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations)) def forward(self, x, mask=None, past_bias=None, optimized_attention=None): x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention) @@ -183,11 +190,11 @@ class T5Block(torch.nn.Module): return x, past_bias class T5Stack(torch.nn.Module): - def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations): + def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention, dtype, device, operations): super().__init__() self.block = torch.nn.ModuleList( - [T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)] + [T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=((not relative_attention) or (i == 0)), dtype=dtype, device=device, operations=operations) for i in range(num_layers)] ) self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) # self.dropout = nn.Dropout(config.dropout_rate) @@ -216,7 +223,7 @@ class T5(torch.nn.Module): self.num_layers = config_dict["num_layers"] model_dim = config_dict["d_model"] - self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations) + self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] == "t5", dtype, device, operations) self.dtype = dtype self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device) diff --git a/comfy/t5_config_base.json b/comfy/t5_config_base.json index facd85ef3..71f68327c 100644 --- a/comfy/t5_config_base.json +++ b/comfy/t5_config_base.json @@ -8,6 +8,7 @@ "dense_act_fn": "relu", "initializer_factor": 1.0, "is_encoder_decoder": true, + "is_gated_act": false, "layer_norm_epsilon": 1e-06, "model_type": "t5", "num_decoder_layers": 12, diff --git a/comfy/t5_config_xxl.json b/comfy/t5_config_xxl.json index bf4feadcf..28283b51a 100644 --- a/comfy/t5_config_xxl.json +++ b/comfy/t5_config_xxl.json @@ -8,6 +8,7 @@ "dense_act_fn": "gelu_pytorch_tanh", "initializer_factor": 1.0, "is_encoder_decoder": true, + "is_gated_act": true, "layer_norm_epsilon": 1e-06, "model_type": "t5", "num_decoder_layers": 24, diff --git a/comfy/text_encoders/__init__.py b/comfy/text_encoders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/text_encoders/aura_t5.py b/comfy/text_encoders/aura_t5.py new file mode 100644 index 000000000..94ebd868b --- /dev/null +++ b/comfy/text_encoders/aura_t5.py @@ -0,0 +1,28 @@ +from importlib import resources + +from comfy import sd1_clip +from .llama_tokenizer import LLAMATokenizer +from .. import t5 +from ..component_model.files import get_path_as_dict + + +class PT5XlModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None): + textmodel_json_config = get_path_as_dict(textmodel_json_config, "t5_pile_config_xl.json", package="comfy.text_encoders") + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=t5.T5, enable_attention_masks=True, zero_out_masked=True) + + +class PT5XlTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None): + tokenizer_path = resources.files("comfy.text_encoders.t5_pile_tokenizer") / "tokenizer.model" + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LLAMATokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1) + + +class AuraT5Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer) + + +class AuraT5Model(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, **kwargs): + super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs) diff --git a/comfy/text_encoders/llama_tokenizer.py b/comfy/text_encoders/llama_tokenizer.py new file mode 100644 index 000000000..a6db1da62 --- /dev/null +++ b/comfy/text_encoders/llama_tokenizer.py @@ -0,0 +1,22 @@ +import os + +class LLAMATokenizer: + @staticmethod + def from_pretrained(path): + return LLAMATokenizer(path) + + def __init__(self, tokenizer_path): + import sentencepiece + self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path) + self.end = self.tokenizer.eos_id() + + def get_vocab(self): + out = {} + for i in range(self.tokenizer.get_piece_size()): + out[self.tokenizer.id_to_piece(i)] = i + return out + + def __call__(self, string): + out = self.tokenizer.encode(string) + out += [self.end] + return {"input_ids": out} diff --git a/comfy/text_encoders/t5_pile_config_xl.json b/comfy/text_encoders/t5_pile_config_xl.json new file mode 100644 index 000000000..ee4e03f97 --- /dev/null +++ b/comfy/text_encoders/t5_pile_config_xl.json @@ -0,0 +1,22 @@ +{ + "d_ff": 5120, + "d_kv": 64, + "d_model": 2048, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 2, + "dense_act_fn": "gelu_pytorch_tanh", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "umt5", + "num_decoder_layers": 24, + "num_heads": 32, + "num_layers": 24, + "output_past": true, + "pad_token_id": 1, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "vocab_size": 32128 +} diff --git a/comfy/text_encoders/t5_pile_tokenizer/__init__.py b/comfy/text_encoders/t5_pile_tokenizer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/text_encoders/t5_pile_tokenizer/tokenizer.model b/comfy/text_encoders/t5_pile_tokenizer/tokenizer.model new file mode 100644 index 000000000..22bccbcb4 Binary files /dev/null and b/comfy/text_encoders/t5_pile_tokenizer/tokenizer.model differ diff --git a/comfy/utils.py b/comfy/utils.py index c2ee3422e..11884b311 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -20,6 +20,8 @@ from PIL import Image from tqdm import tqdm from . import checkpoint_pickle, interruption +from .component_model import files +from .component_model.deprecation import _deprecate_method from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage from .component_model.queue_types import BinaryEventTypes from .execution_context import current_execution_context @@ -374,6 +376,76 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""): return key_map +def auraflow_to_diffusers(mmdit_config, output_prefix=""): + n_double_layers = mmdit_config.get("n_double_layers", 0) + n_layers = mmdit_config.get("n_layers", 0) + + key_map = {} + for i in range(n_layers): + if i < n_double_layers: + index = i + prefix_from = "joint_transformer_blocks" + prefix_to = "{}double_layers".format(output_prefix) + block_map = { + "attn.to_q.weight": "attn.w2q.weight", + "attn.to_k.weight": "attn.w2k.weight", + "attn.to_v.weight": "attn.w2v.weight", + "attn.to_out.0.weight": "attn.w2o.weight", + "attn.add_q_proj.weight": "attn.w1q.weight", + "attn.add_k_proj.weight": "attn.w1k.weight", + "attn.add_v_proj.weight": "attn.w1v.weight", + "attn.to_add_out.weight": "attn.w1o.weight", + "ff.linear_1.weight": "mlpX.c_fc1.weight", + "ff.linear_2.weight": "mlpX.c_fc2.weight", + "ff.out_projection.weight": "mlpX.c_proj.weight", + "ff_context.linear_1.weight": "mlpC.c_fc1.weight", + "ff_context.linear_2.weight": "mlpC.c_fc2.weight", + "ff_context.out_projection.weight": "mlpC.c_proj.weight", + "norm1.linear.weight": "modX.1.weight", + "norm1_context.linear.weight": "modC.1.weight", + } + else: + index = i - n_double_layers + prefix_from = "single_transformer_blocks" + prefix_to = "{}single_layers".format(output_prefix) + + block_map = { + "attn.to_q.weight": "attn.w1q.weight", + "attn.to_k.weight": "attn.w1k.weight", + "attn.to_v.weight": "attn.w1v.weight", + "attn.to_out.0.weight": "attn.w1o.weight", + "norm1.linear.weight": "modCX.1.weight", + "ff.linear_1.weight": "mlp.c_fc1.weight", + "ff.linear_2.weight": "mlp.c_fc2.weight", + "ff.out_projection.weight": "mlp.c_proj.weight" + } + + for k in block_map: + key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k]) + + MAP_BASIC = { + ("positional_encoding", "pos_embed.pos_embed"), + ("register_tokens", "register_tokens"), + ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"), + ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"), + ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"), + ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"), + ("cond_seq_linear.weight", "context_embedder.weight"), + ("init_x_linear.weight", "pos_embed.proj.weight"), + ("init_x_linear.bias", "pos_embed.proj.bias"), + ("final_linear.weight", "proj_out.weight"), + ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift), + } + + for k in MAP_BASIC: + if len(k) > 2: + key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2]) + else: + key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + + return key_map + + def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) @@ -675,8 +747,9 @@ class ProgressBar: self.update_absolute(self.current + value) +@_deprecate_method(version="1.0.0", message="The root project directory isn't valid when the application is installed as a package. Use os.getcwd() instead.") def get_project_root() -> str: - return os.path.join(os.path.dirname(__file__), "..") + return files.get_package_as_path("comfy") @contextmanager diff --git a/comfy/web/scripts/app.js b/comfy/web/scripts/app.js index 67e488bf9..8b4478a32 100644 --- a/comfy/web/scripts/app.js +++ b/comfy/web/scripts/app.js @@ -1599,7 +1599,7 @@ export class ComfyApp { if (json) { const workflow = JSON.parse(json); const workflowName = getStorageValue("Comfy.PreviousWorkflow"); - await this.loadGraphData(workflow, true, workflowName); + await this.loadGraphData(workflow, true, true, workflowName); return true; } }; diff --git a/comfy/web/scripts/ui/menu/workflows.js b/comfy/web/scripts/ui/menu/workflows.js index afdff538a..3b904fb4b 100644 --- a/comfy/web/scripts/ui/menu/workflows.js +++ b/comfy/web/scripts/ui/menu/workflows.js @@ -182,6 +182,11 @@ export class ComfyWorkflowsMenu { * @param {ComfyWorkflow} workflow */ async function sendToWorkflow(img, workflow) { + const openWorkflow = app.workflowManager.openWorkflows.find((w) => w.path === workflow.path); + if (openWorkflow) { + workflow = openWorkflow; + } + await workflow.load(); let options = []; const nodes = app.graph.computeExecutionOrder(false); @@ -214,7 +219,8 @@ export class ComfyWorkflowsMenu { nodeType.prototype["getExtraMenuOptions"] = function (_, options) { const r = getExtraMenuOptions?.apply?.(this, arguments); - if (app.ui.settings.getSettingValue("Comfy.UseNewMenu", false) === true) { + const setting = app.ui.settings.getSettingValue("Comfy.UseNewMenu", false); + if (setting && setting != "Disabled") { const t = /** @type { {imageIndex?: number, overIndex?: number, imgs: string[]} } */ /** @type {any} */ (this); let img; if (t.imageIndex != null) { diff --git a/comfy/web/style.css b/comfy/web/style.css index e983b652a..8ef1d0dd1 100644 --- a/comfy/web/style.css +++ b/comfy/web/style.css @@ -41,7 +41,7 @@ body { background-color: var(--bg-color); color: var(--fg-color); grid-template-columns: auto 1fr auto; - grid-template-rows: auto auto 1fr auto; + grid-template-rows: auto 1fr auto; min-height: -webkit-fill-available; max-height: -webkit-fill-available; min-width: -webkit-fill-available; @@ -49,32 +49,37 @@ body { } .comfyui-body-top { - order: 0; + order: -5; grid-column: 1/-1; z-index: 10; + display: flex; + flex-direction: column; } .comfyui-body-left { - order: 1; + order: -4; z-index: 10; + display: flex; } #graph-canvas { width: 100%; height: 100%; - order: 2; - grid-column: 1/-1; + order: -3; } .comfyui-body-right { - order: 3; + order: -2; z-index: 10; + display: flex; } .comfyui-body-bottom { - order: 4; + order: -1; grid-column: 1/-1; z-index: 10; + display: flex; + flex-direction: column; } .comfy-multiline-input { @@ -408,8 +413,12 @@ dialog::backdrop { background: rgba(0, 0, 0, 0.5); } -.comfy-dialog.comfyui-dialog { +.comfy-dialog.comfyui-dialog.comfy-modal { top: 0; + left: 0; + right: 0; + bottom: 0; + transform: none; } .comfy-dialog.comfy-modal { diff --git a/comfy_extras/nodes/nodes_audio.py b/comfy_extras/nodes/nodes_audio.py index 8ab3541ed..bfdd767a6 100644 --- a/comfy_extras/nodes/nodes_audio.py +++ b/comfy_extras/nodes/nodes_audio.py @@ -20,7 +20,7 @@ class EmptyLatentAudio: RETURN_TYPES = ("LATENT",) FUNCTION = "generate" - CATEGORY = "_for_testing/audio" + CATEGORY = "latent/audio" def generate(self, seconds): batch_size = 1 @@ -35,7 +35,7 @@ class VAEEncodeAudio: RETURN_TYPES = ("LATENT",) FUNCTION = "encode" - CATEGORY = "_for_testing/audio" + CATEGORY = "latent/audio" def encode(self, vae, audio): sample_rate = audio["sample_rate"] @@ -55,7 +55,7 @@ class VAEDecodeAudio: RETURN_TYPES = ("AUDIO",) FUNCTION = "decode" - CATEGORY = "_for_testing/audio" + CATEGORY = "latent/audio" def decode(self, vae, samples): audio = vae.decode(samples["samples"]).movedim(-1, 1) @@ -134,7 +134,7 @@ class SaveAudio: OUTPUT_NODE = True - CATEGORY = "_for_testing/audio" + CATEGORY = "audio" def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): import torchaudio # pylint: disable=import-error @@ -199,7 +199,7 @@ class LoadAudio: ] return {"required": {"audio": (sorted(files), {"audio_upload": True})}} - CATEGORY = "_for_testing/audio" + CATEGORY = "audio" RETURN_TYPES = ("AUDIO", ) FUNCTION = "load" @@ -209,7 +209,6 @@ class LoadAudio: audio_path = folder_paths.get_annotated_filepath(audio) waveform, sample_rate = torchaudio.load(audio_path) - multiplier = 1.0 audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} return (audio, ) diff --git a/comfy_extras/nodes/nodes_model_advanced.py b/comfy_extras/nodes/nodes_model_advanced.py index 8804ed8a3..d2cb94047 100644 --- a/comfy_extras/nodes/nodes_model_advanced.py +++ b/comfy_extras/nodes/nodes_model_advanced.py @@ -147,7 +147,7 @@ class ModelSamplingSD3: CATEGORY = "advanced/model" - def patch(self, model, shift): + def patch(self, model, shift, multiplier=1000): m = model.clone() sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow @@ -157,10 +157,22 @@ class ModelSamplingSD3: pass model_sampling = ModelSamplingAdvanced(model.model.model_config) - model_sampling.set_parameters(shift=shift) + model_sampling.set_parameters(shift=shift, multiplier=multiplier) m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingAuraFlow(ModelSamplingSD3): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 1.73, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + FUNCTION = "patch_aura" + + def patch_aura(self, model, shift): + return self.patch(model, shift, multiplier=1.0) + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): @@ -276,5 +288,6 @@ NODE_CLASS_MAPPINGS = { "ModelSamplingContinuousV": ModelSamplingContinuousV, "ModelSamplingStableCascade": ModelSamplingStableCascade, "ModelSamplingSD3": ModelSamplingSD3, + "ModelSamplingAuraFlow": ModelSamplingAuraFlow, "RescaleCFG": RescaleCFG, } diff --git a/requirements.txt b/requirements.txt index 21c1b360d..49941d635 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,8 @@ torchsde>=0.2.6 einops>=0.6.0 open-clip-torch>=2.24.0 transformers>=4.29.1 +tokenizers>=0.13.3 +sentencepiece peft torchinfo safetensors>=0.4.2 diff --git a/setup.py b/setup.py index ce6b4c422..0da0c8428 100644 --- a/setup.py +++ b/setup.py @@ -192,6 +192,7 @@ package_data = [ 't5_tokenizer/*', '**/*.json', '**/*.yaml', + '**/*.model' ] if not is_editable: package_data.append('comfy/web/**/*') diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py index 8df9d6d36..26d455956 100644 --- a/tests/distributed/test_distributed_queue.py +++ b/tests/distributed/test_distributed_queue.py @@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor import jwt import pytest -from aiohttp import ClientSession, ClientConnectorError +from aiohttp import ClientSession from testcontainers.rabbitmq import RabbitMqContainer from comfy.client.aio_client import AsyncRemoteComfyClient @@ -132,13 +132,11 @@ async def test_basic_queue_worker_with_health_check(): health_check_port = 9090 async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker: - # Test health check health_check_url = f"http://localhost:{health_check_port}/health" health_check_ok = await check_health(health_check_url) assert health_check_ok, "Health check server did not start properly" - # Test the actual worker functionality from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) await distributed_queue.init() @@ -153,53 +151,5 @@ async def test_basic_queue_worker_with_health_check(): await distributed_queue.close() - # Test that the health check server is stopped after the worker is closed health_check_stopped = not await check_health(health_check_url, max_retries=1) - assert health_check_stopped, "Health check server did not stop properly" - - -@pytest.mark.asyncio -async def test_health_check_port_conflict(): - with RabbitMqContainer("rabbitmq:latest") as rabbitmq: - params = rabbitmq.get_connection_params() - connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" - health_check_port = 9090 - - # Start a simple server to occupy the health check port - from aiohttp import web - async def dummy_handler(request): - return web.Response(text="Dummy") - - app = web.Application() - app.router.add_get('/', dummy_handler) - runner = web.AppRunner(app) - await runner.setup() - site = web.TCPSite(runner, '0.0.0.0', health_check_port) - await site.start() - - try: - # Now try to start the DistributedPromptWorker - async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker: - # The health check should be disabled, but the worker should still function - from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue - distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) - await distributed_queue.init() - - queue_item = create_test_prompt() - res = await distributed_queue.put_async(queue_item) - - assert res.item_id == queue_item.prompt_id - assert len(res.outputs) == 1 - assert res.status is not None - assert res.status.status_str == "success" - - await distributed_queue.close() - - # The original server should still be running - async with ClientSession() as session: - async with session.get(f"http://localhost:{health_check_port}") as response: - assert response.status == 200 - assert await response.text() == "Dummy" - - finally: - await runner.cleanup() + assert health_check_stopped, "Health check server did not stop properly" \ No newline at end of file diff --git a/tests/inference/test_workflows.py b/tests/inference/test_workflows.py index 673bcab78..cdeae9928 100644 --- a/tests/inference/test_workflows.py +++ b/tests/inference/test_workflows.py @@ -1,7 +1,6 @@ import pytest from comfy.api.components.schema.prompt import Prompt -from comfy.cli_args_types import Configuration from comfy.client.embedded_comfy_client import EmbeddedComfyClient from comfy.model_downloader import add_known_models, KNOWN_LORAS from comfy.model_downloader_types import CivitFile @@ -139,9 +138,7 @@ _workflows = { @pytest.fixture(scope="module", autouse=False) @pytest.mark.asyncio async def client(tmp_path_factory) -> EmbeddedComfyClient: - config = Configuration() - config.cwd = str(tmp_path_factory.mktemp("comfy_test_cwd")) - async with EmbeddedComfyClient(config) as client: + async with EmbeddedComfyClient() as client: yield client