mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Merge commit '39fb74c5bd13a1dccf4d7293a2f7a755d9f43cbd' of github.com:comfyanonymous/ComfyUI
- Improvements to tests - Fixes model management - Fixes issues with language nodes
This commit is contained in:
commit
0549f35e85
53
.github/workflows/pullrequest-ci-run.yml
vendored
Normal file
53
.github/workflows/pullrequest-ci-run.yml
vendored
Normal file
@ -0,0 +1,53 @@
|
||||
# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
|
||||
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
|
||||
name: Pull Request CI Workflow Runs
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [labeled]
|
||||
|
||||
jobs:
|
||||
pr-test-stable:
|
||||
if: ${{ github.event.label.name == 'Run-CI-Test' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux, windows]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
uses: comfy-org/comfy-action@main
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python_version: ${{ matrix.python_version }}
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
use_prior_commit: 'true'
|
||||
comment:
|
||||
if: ${{ github.event.label.name == 'Run-CI-Test' }}
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
github.rest.issues.createComment({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
|
||||
})
|
||||
95
.github/workflows/test-ci.yml
vendored
Normal file
95
.github/workflows/test-ci.yml
vendored
Normal file
@ -0,0 +1,95 @@
|
||||
# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
|
||||
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
|
||||
name: Full Comfy CI Workflow Runs
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- 'app/**'
|
||||
- 'input/**'
|
||||
- 'output/**'
|
||||
- 'notebooks/**'
|
||||
- 'script_examples/**'
|
||||
- '.github/**'
|
||||
- 'web/**'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test-stable:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux, windows]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
uses: comfy-org/comfy-action@main
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python_version: ${{ matrix.python_version }}
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
test-win-nightly:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
os: [windows]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: windows
|
||||
runner_label: [self-hosted, win]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
uses: comfy-org/comfy-action@main
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python_version: ${{ matrix.python_version }}
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
|
||||
test-unix-nightly:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux]
|
||||
python_version: ["3.11"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
runs-on: ${{ matrix.runner_label }}
|
||||
steps:
|
||||
- name: Test Workflows
|
||||
uses: comfy-org/comfy-action@main
|
||||
with:
|
||||
os: ${{ matrix.os }}
|
||||
python_version: ${{ matrix.python_version }}
|
||||
torch_version: ${{ matrix.torch_version }}
|
||||
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
|
||||
comfyui_flags: ${{ matrix.flags }}
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -176,4 +176,5 @@ cython_debug/
|
||||
/tests-ui/data/object_info.json
|
||||
/user/
|
||||
*.log
|
||||
web_custom_versions/
|
||||
web_custom_versions/
|
||||
.DS_Store
|
||||
|
||||
@ -14,9 +14,9 @@ from typing import TypedDict
|
||||
import requests
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||
from comfy.cmd.folder_paths import add_model_folder_path
|
||||
from comfy.component_model.files import get_package_as_path
|
||||
from ..cli_args import DEFAULT_VERSION_STRING
|
||||
from ..cmd.folder_paths import add_model_folder_path
|
||||
from ..component_model.files import get_package_as_path
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
@ -10,7 +10,6 @@ import time
|
||||
import traceback
|
||||
import typing
|
||||
from os import PathLike
|
||||
from pathlib import PurePath
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import lazy_object_proxy
|
||||
@ -34,9 +33,11 @@ from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOption
|
||||
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
|
||||
# various functions that are declared here. It should have been a context in the first place.
|
||||
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
|
||||
from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy.graph_utils import is_link, GraphBuilder
|
||||
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
|
||||
# order matters
|
||||
from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from ..graph_utils import is_link, GraphBuilder
|
||||
from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
|
||||
|
||||
class IsChangedCache:
|
||||
@ -446,6 +447,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"current_inputs": input_data_formatted
|
||||
}
|
||||
|
||||
if isinstance(ex, model_management.OOM_EXCEPTION):
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
model_management.unload_all_models()
|
||||
|
||||
return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)
|
||||
|
||||
executed.add(unique_id)
|
||||
|
||||
@ -55,7 +55,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
logging.debug("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
|
||||
flags = q.get_flags()
|
||||
free_memory = flags.get("free_memory", False)
|
||||
|
||||
@ -838,7 +838,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
self.port = port
|
||||
|
||||
if verbose:
|
||||
logging.info("Starting server\n")
|
||||
logging.info("Starting server")
|
||||
logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
|
||||
if call_on_start is not None:
|
||||
call_on_start(address, port)
|
||||
|
||||
@ -1,4 +1,24 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
@ -12,7 +32,8 @@ from . import latent_formats
|
||||
|
||||
from .cldm import cldm, mmdit
|
||||
from .t2i_adapter import adapter
|
||||
from .ldm.cascade import controlnet
|
||||
from .ldm import hydit, flux
|
||||
from .ldm.cascade import controlnet as cascade_controlnet
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
@ -33,6 +54,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
else:
|
||||
return torch.cat([tensor] * batched_number, dim=0)
|
||||
|
||||
class StrengthType(Enum):
|
||||
CONSTANT = 1
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self, device=None):
|
||||
self.cond_hint_original = None
|
||||
@ -51,6 +76,8 @@ class ControlBase:
|
||||
device = model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
||||
self.cond_hint_original = cond_hint
|
||||
@ -93,6 +120,8 @@ class ControlBase:
|
||||
c.latent_format = self.latent_format
|
||||
c.extra_args = self.extra_args.copy()
|
||||
c.vae = self.vae
|
||||
c.extra_conds = self.extra_conds.copy()
|
||||
c.strength_type = self.strength_type
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
@ -113,7 +142,10 @@ class ControlBase:
|
||||
|
||||
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
|
||||
applied_to.add(x)
|
||||
x *= self.strength
|
||||
if self.strength_type == StrengthType.CONSTANT:
|
||||
x *= self.strength
|
||||
elif self.strength_type == StrengthType.LINEAR_UP:
|
||||
x *= (self.strength ** float(len(control_output) - i))
|
||||
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
@ -142,7 +174,7 @@ class ControlBase:
|
||||
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, ckpt_name: str = None):
|
||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, ckpt_name: str = None):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
@ -154,6 +186,8 @@ class ControlNet(ControlBase):
|
||||
self.model_sampling_current = None
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
self.latent_format = latent_format
|
||||
self.extra_conds += extra_conds
|
||||
self.strength_type = strength_type
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
control_prev = None
|
||||
@ -191,13 +225,16 @@ class ControlNet(ControlBase):
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
||||
y = cond.get('y', None)
|
||||
if y is not None:
|
||||
y = y.to(dtype)
|
||||
extra = self.extra_args.copy()
|
||||
for c in self.extra_conds:
|
||||
temp = cond.get(c, None)
|
||||
if temp is not None:
|
||||
extra[c] = temp.to(dtype)
|
||||
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
@ -286,6 +323,7 @@ class ControlLora(ControlNet):
|
||||
ControlBase.__init__(self, device)
|
||||
self.control_weights = control_weights
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.extra_conds += ["y"]
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
@ -338,12 +376,8 @@ class ControlLora(ControlNet):
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config = model_detection.model_config_from_unet(new_sd, "", True)
|
||||
num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
def controlnet_config(sd):
|
||||
model_config = model_detection.model_config_from_unet(sd, "", True)
|
||||
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
|
||||
@ -356,14 +390,27 @@ def load_controlnet_mmdit(sd):
|
||||
else:
|
||||
operations = ops.disable_weight_init
|
||||
|
||||
control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
|
||||
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
||||
|
||||
def controlnet_load_state_dict(control_model, sd):
|
||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
return control_model
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
||||
num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = latent_formats.SD3()
|
||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||
@ -371,8 +418,30 @@ def load_controlnet_mmdit(sd):
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
||||
|
||||
control_model = hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||
|
||||
latent_format = latent_formats.SDXL()
|
||||
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||
return control
|
||||
|
||||
def load_controlnet_flux_xlabs(sd):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
|
||||
control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||
return load_controlnet_hunyuandit(controlnet_data)
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
return ControlLora(controlnet_data)
|
||||
|
||||
@ -430,7 +499,10 @@ def load_controlnet(ckpt_path, model=None):
|
||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||
controlnet_data = new_sd
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
||||
return load_controlnet_mmdit(controlnet_data)
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||
return load_controlnet_flux_xlabs(controlnet_data)
|
||||
else:
|
||||
return load_controlnet_mmdit(controlnet_data)
|
||||
|
||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||
pth = False
|
||||
@ -590,11 +662,11 @@ def load_t2i_adapter(t2i_data):
|
||||
xl = True
|
||||
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
|
||||
elif "backbone.0.0.weight" in keys:
|
||||
model_ad = controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
|
||||
model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
|
||||
compression_ratio = 32
|
||||
upscale_algorithm = 'bilinear'
|
||||
elif "backbone.10.blocks.0.weight" in keys:
|
||||
model_ad = controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
|
||||
model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
|
||||
compression_ratio = 1
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
else:
|
||||
|
||||
@ -28,7 +28,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
|
||||
|
||||
unet = None
|
||||
if unet_path is not None:
|
||||
unet = sd.load_unet(unet_path)
|
||||
unet = sd.load_diffusion_model(unet_path)
|
||||
|
||||
clip = None
|
||||
textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"])
|
||||
|
||||
@ -8,7 +8,7 @@ from tqdm.auto import trange, tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
import comfy.model_patcher
|
||||
from .. import model_patcher
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
@ -1032,7 +1032,7 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
@ -1058,7 +1058,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
|
||||
@ -77,19 +77,9 @@ class TransformersManagedModel(ModelManageable):
|
||||
|
||||
return self.model.config.to_dict()
|
||||
|
||||
@property
|
||||
def lowvram_patch_counter(self):
|
||||
return 0
|
||||
|
||||
@lowvram_patch_counter.setter
|
||||
def lowvram_patch_counter(self, value: int):
|
||||
warnings.warn("Not supported")
|
||||
pass
|
||||
|
||||
load_device: torch.device
|
||||
offload_device: torch.device
|
||||
model: PreTrainedModel
|
||||
|
||||
@property
|
||||
def current_device(self) -> torch.device:
|
||||
return self.model.device
|
||||
@ -127,12 +117,10 @@ class TransformersManagedModel(ModelManageable):
|
||||
warnings.warn("Transformers models do not currently support adapters like LoRAs")
|
||||
return self.model.to(device=device_to)
|
||||
|
||||
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
|
||||
warnings.warn("Transformers models do not currently support adapters like LoRAs")
|
||||
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
|
||||
return self.model.to(device=device_to)
|
||||
|
||||
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
warnings.warn("Transformers models do not currently support adapters like LoRAs")
|
||||
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
return self.model.to(device=offload_device)
|
||||
|
||||
def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
|
||||
@ -177,7 +165,7 @@ class TransformersManagedModel(ModelManageable):
|
||||
self.processor.to(device=self.load_device)
|
||||
|
||||
assert "<image>" in prompt.lower(), "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
|
||||
batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
|
||||
batch_feature: BatchFeature = self.processor([prompt], images=images.unbind(), padding=True, return_tensors="pt")
|
||||
if hasattr(self.processor, "to"):
|
||||
self.processor.to(device=self.offload_device)
|
||||
assert "input_ids" in batch_feature
|
||||
|
||||
104
comfy/ldm/flux/controlnet_xlabs.py
Normal file
104
comfy/ldm/flux/controlnet_xlabs.py
Normal file
@ -0,0 +1,104 @@
|
||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
|
||||
from .model import Flux
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
class ControlNetFlux(Flux):
|
||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||
|
||||
# add ControlNet blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(self.params.depth):
|
||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
# controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.gradient_checkpointing = False
|
||||
self.input_hint_block = nn.Sequential(
|
||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
controlnet_cond: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||
img = img + controlnet_cond
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
block_res_samples = ()
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
block_res_samples = block_res_samples + (img,)
|
||||
|
||||
controlnet_block_res_samples = ()
|
||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
||||
block_res_sample = controlnet_block(block_res_sample)
|
||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
||||
|
||||
return {"output": (controlnet_block_res_samples * 10)[:19]}
|
||||
|
||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
||||
hint = hint * 2.0 - 1.0
|
||||
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||
|
||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
|
||||
@ -2,12 +2,12 @@ import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
from ... import ops
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||
super().__init__()
|
||||
@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
t.device
|
||||
)
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
@ -48,7 +46,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
embedding = embedding.to(t)
|
||||
return embedding
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@ -94,14 +91,6 @@ class SelfAttention(nn.Module):
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
x = attention(q, k, v, pe=pe)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
@ -163,22 +152,21 @@ class DoubleStreamBlock(nn.Module):
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2), pe=pe)
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
@ -186,8 +174,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = txt.clip(-65504, 65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
@ -232,14 +224,17 @@ class SingleStreamBlock(nn.Module):
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
x += mod.gate * output
|
||||
if x.dtype == torch.float16:
|
||||
x = x.clip(-65504, 65504)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
|
||||
@ -38,7 +38,7 @@ class Flux(nn.Module):
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = FluxParams(**kwargs)
|
||||
@ -83,7 +83,8 @@ class Flux(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@ -94,6 +95,7 @@ class Flux(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control=None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@ -112,8 +114,15 @@ class Flux(nn.Module):
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
for i in range(len(self.double_blocks)):
|
||||
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
if control is not None: #Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
@ -123,7 +132,7 @@ class Flux(nn.Module):
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, **kwargs):
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
patch_size = 2
|
||||
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||
@ -138,5 +147,5 @@ class Flux(nn.Module):
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||
|
||||
@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
@ -78,10 +78,9 @@ def apply_rotary_emb(
|
||||
xk_out = None
|
||||
if isinstance(freqs_cis, tuple):
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
||||
xq_out = (xq * cos + rotate_half(xq) * sin)
|
||||
if xk is not None:
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
||||
xk_out = (xk * cos + rotate_half(xk) * sin)
|
||||
else:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
|
||||
|
||||
321
comfy/ldm/hydit/controlnet.py
Normal file
321
comfy/ldm/hydit/controlnet.py
Normal file
@ -0,0 +1,321 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.utils import checkpoint
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import (
|
||||
Mlp,
|
||||
TimestepEmbedder,
|
||||
PatchEmbed,
|
||||
RMSNorm,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from .poolers import AttentionPool
|
||||
|
||||
import comfy.latent_formats
|
||||
from .models import HunYuanDiTBlock, calc_rope
|
||||
|
||||
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
|
||||
|
||||
|
||||
class HunYuanControlNet(nn.Module):
|
||||
"""
|
||||
HunYuanDiT: Diffusion model with a Transformer backbone.
|
||||
|
||||
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
||||
|
||||
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args: argparse.Namespace
|
||||
The arguments parsed by argparse.
|
||||
input_size: tuple
|
||||
The size of the input image.
|
||||
patch_size: int
|
||||
The size of the patch.
|
||||
in_channels: int
|
||||
The number of input channels.
|
||||
hidden_size: int
|
||||
The hidden size of the transformer backbone.
|
||||
depth: int
|
||||
The number of transformer blocks.
|
||||
num_heads: int
|
||||
The number of attention heads.
|
||||
mlp_ratio: float
|
||||
The ratio of the hidden size of the MLP in the transformer block.
|
||||
log_fn: callable
|
||||
The logging function.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: tuple = 128,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
hidden_size: int = 1408,
|
||||
depth: int = 40,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.3637,
|
||||
text_states_dim=1024,
|
||||
text_states_dim_t5=2048,
|
||||
text_len=77,
|
||||
text_len_t5=256,
|
||||
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
|
||||
size_cond=False,
|
||||
use_style_cond=False,
|
||||
learn_sigma=True,
|
||||
norm="layer",
|
||||
log_fn: callable = print,
|
||||
attn_precision=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.log_fn = log_fn
|
||||
self.depth = depth
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.text_states_dim = text_states_dim
|
||||
self.text_states_dim_t5 = text_states_dim_t5
|
||||
self.text_len = text_len
|
||||
self.text_len_t5 = text_len_t5
|
||||
self.size_cond = size_cond
|
||||
self.use_style_cond = use_style_cond
|
||||
self.norm = norm
|
||||
self.dtype = dtype
|
||||
self.latent_format = comfy.latent_formats.SDXL
|
||||
|
||||
self.mlp_t5 = nn.Sequential(
|
||||
nn.Linear(
|
||||
self.text_states_dim_t5,
|
||||
self.text_states_dim_t5 * 4,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
self.text_states_dim_t5 * 4,
|
||||
self.text_states_dim,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
# learnable replace
|
||||
self.text_embedding_padding = nn.Parameter(
|
||||
torch.randn(
|
||||
self.text_len + self.text_len_t5,
|
||||
self.text_states_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
# Attention pooling
|
||||
pooler_out_dim = 1024
|
||||
self.pooler = AttentionPool(
|
||||
self.text_len_t5,
|
||||
self.text_states_dim_t5,
|
||||
num_heads=8,
|
||||
output_dim=pooler_out_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Dimension of the extra input vectors
|
||||
self.extra_in_dim = pooler_out_dim
|
||||
|
||||
if self.size_cond:
|
||||
# Image size and crop size conditions
|
||||
self.extra_in_dim += 6 * 256
|
||||
|
||||
if self.use_style_cond:
|
||||
# Here we use a default learned embedder layer for future extension.
|
||||
self.style_embedder = nn.Embedding(
|
||||
1, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
self.extra_in_dim += hidden_size
|
||||
|
||||
# Text embedding for `add`
|
||||
self.x_embedder = PatchEmbed(
|
||||
input_size,
|
||||
patch_size,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.extra_embedder = nn.Sequential(
|
||||
operations.Linear(
|
||||
self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
|
||||
),
|
||||
nn.SiLU(),
|
||||
operations.Linear(
|
||||
hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
|
||||
# Image embedding
|
||||
num_patches = self.x_embedder.num_patches
|
||||
|
||||
# HUnYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
HunYuanDiTBlock(
|
||||
hidden_size=hidden_size,
|
||||
c_emb_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
text_states_dim=self.text_states_dim,
|
||||
qk_norm=qk_norm,
|
||||
norm_type=self.norm,
|
||||
skip=False,
|
||||
attn_precision=attn_precision,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for _ in range(19)
|
||||
]
|
||||
)
|
||||
|
||||
# Input zero linear for the first block
|
||||
self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||
|
||||
|
||||
# Output zero linear for the every block
|
||||
self.after_proj_list = nn.ModuleList(
|
||||
[
|
||||
|
||||
operations.Linear(
|
||||
self.hidden_size, self.hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
for _ in range(len(self.blocks))
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
hint,
|
||||
timesteps,
|
||||
context,#encoder_hidden_states=None,
|
||||
text_embedding_mask=None,
|
||||
encoder_hidden_states_t5=None,
|
||||
text_embedding_mask_t5=None,
|
||||
image_meta_size=None,
|
||||
style=None,
|
||||
return_dict=False,
|
||||
**kwarg,
|
||||
):
|
||||
"""
|
||||
Forward pass of the encoder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x: torch.Tensor
|
||||
(B, D, H, W)
|
||||
t: torch.Tensor
|
||||
(B)
|
||||
encoder_hidden_states: torch.Tensor
|
||||
CLIP text embedding, (B, L_clip, D)
|
||||
text_embedding_mask: torch.Tensor
|
||||
CLIP text embedding mask, (B, L_clip)
|
||||
encoder_hidden_states_t5: torch.Tensor
|
||||
T5 text embedding, (B, L_t5, D)
|
||||
text_embedding_mask_t5: torch.Tensor
|
||||
T5 text embedding mask, (B, L_t5)
|
||||
image_meta_size: torch.Tensor
|
||||
(B, 6)
|
||||
style: torch.Tensor
|
||||
(B)
|
||||
cos_cis_img: torch.Tensor
|
||||
sin_cis_img: torch.Tensor
|
||||
return_dict: bool
|
||||
Whether to return a dictionary.
|
||||
"""
|
||||
condition = hint
|
||||
if condition.shape[0] == 1:
|
||||
condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
|
||||
|
||||
text_states = context # 2,77,1024
|
||||
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
|
||||
text_states_mask = text_embedding_mask.bool() # 2,77
|
||||
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
|
||||
b_t5, l_t5, c_t5 = text_states_t5.shape
|
||||
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
|
||||
|
||||
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
|
||||
|
||||
text_states[:, -self.text_len :] = torch.where(
|
||||
text_states_mask[:, -self.text_len :].unsqueeze(2),
|
||||
text_states[:, -self.text_len :],
|
||||
padding[: self.text_len],
|
||||
)
|
||||
text_states_t5[:, -self.text_len_t5 :] = torch.where(
|
||||
text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
|
||||
text_states_t5[:, -self.text_len_t5 :],
|
||||
padding[self.text_len :],
|
||||
)
|
||||
|
||||
text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,205,1024
|
||||
|
||||
# _, _, oh, ow = x.shape
|
||||
# th, tw = oh // self.patch_size, ow // self.patch_size
|
||||
|
||||
# Get image RoPE embedding according to `reso`lution.
|
||||
freqs_cis_img = calc_rope(
|
||||
x, self.patch_size, self.hidden_size // self.num_heads
|
||||
) # (cos_cis_img, sin_cis_img)
|
||||
|
||||
# ========================= Build time and image embedding =========================
|
||||
t = self.t_embedder(timesteps, dtype=self.dtype)
|
||||
x = self.x_embedder(x)
|
||||
|
||||
# ========================= Concatenate all extra vectors =========================
|
||||
# Build text tokens with pooling
|
||||
extra_vec = self.pooler(encoder_hidden_states_t5)
|
||||
|
||||
# Build image meta size tokens if applicable
|
||||
# if image_meta_size is not None:
|
||||
# image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
|
||||
# if image_meta_size.dtype != self.dtype:
|
||||
# image_meta_size = image_meta_size.half()
|
||||
# image_meta_size = image_meta_size.view(-1, 6 * 256)
|
||||
# extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
|
||||
|
||||
# Build style tokens
|
||||
if style is not None:
|
||||
style_embedding = self.style_embedder(style)
|
||||
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
|
||||
|
||||
# Concatenate all extra vectors
|
||||
c = t + self.extra_embedder(extra_vec) # [B, D]
|
||||
|
||||
# ========================= Deal with Condition =========================
|
||||
condition = self.x_embedder(condition)
|
||||
|
||||
# ========================= Forward pass through HunYuanDiT blocks =========================
|
||||
controls = []
|
||||
x = x + self.before_proj(condition) # add condition
|
||||
for layer, block in enumerate(self.blocks):
|
||||
x = block(x, c, text_states, freqs_cis_img)
|
||||
controls.append(self.after_proj_list[layer](x)) # zero linear for output
|
||||
|
||||
return {"output": controls}
|
||||
@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size):
|
||||
sub_args = [start, stop, (th, tw)]
|
||||
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
|
||||
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
|
||||
rope = (rope[0].to(x), rope[1].to(x))
|
||||
return rope
|
||||
|
||||
|
||||
@ -91,6 +92,8 @@ class HunYuanDiTBlock(nn.Module):
|
||||
# Long Skip Connection
|
||||
if self.skip_linear is not None:
|
||||
cat = torch.cat([x, skip], dim=-1)
|
||||
if cat.dtype != x.dtype:
|
||||
cat = cat.to(x.dtype)
|
||||
cat = self.skip_norm(cat)
|
||||
x = self.skip_linear(cat)
|
||||
|
||||
@ -362,6 +365,8 @@ class HunYuanDiT(nn.Module):
|
||||
c = t + self.extra_embedder(extra_vec) # [B, D]
|
||||
|
||||
controls = None
|
||||
if control:
|
||||
controls = control.get("output", None)
|
||||
# ========================= Forward pass through HunYuanDiT blocks =========================
|
||||
skips = []
|
||||
for layer, block in enumerate(self.blocks):
|
||||
|
||||
@ -411,17 +411,17 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
logging.info("Using xformers cross attention")
|
||||
logging.debug("Using xformers cross attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch cross attention")
|
||||
logging.debug("Using pytorch cross attention")
|
||||
optimized_attention = attention_pytorch
|
||||
else:
|
||||
if args.use_split_cross_attention:
|
||||
logging.info("Using split optimization for cross attention")
|
||||
logging.debug("Using split optimization for cross attention")
|
||||
optimized_attention = attention_split
|
||||
else:
|
||||
logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
logging.debug("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
@ -268,13 +268,13 @@ class AttnBlock(nn.Module):
|
||||
padding=0)
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
logging.info("Using xformers attention in VAE")
|
||||
logging.debug("Using xformers attention in VAE")
|
||||
self.optimized_attention = xformers_attention
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention in VAE")
|
||||
logging.debug("Using pytorch attention in VAE")
|
||||
self.optimized_attention = pytorch_attention
|
||||
else:
|
||||
logging.info("Using split attention in VAE")
|
||||
logging.debug("Using split attention in VAE")
|
||||
self.optimized_attention = normal_attention
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from . import utils
|
||||
from . import model_base
|
||||
@ -216,11 +234,17 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map[lora_key] = k
|
||||
|
||||
for k in sdk: #OneTrainer SD3 lora
|
||||
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
|
||||
l_key = k[len("t5xxl.transformer."):-len(".weight")]
|
||||
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
|
||||
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
|
||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||
key_map[lora_key] = k
|
||||
|
||||
|
||||
k = "clip_g.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
@ -243,6 +267,7 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
|
||||
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
|
||||
diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
|
||||
for k in diffusers_keys:
|
||||
|
||||
@ -1,7 +1,26 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import TypeVar, Type
|
||||
|
||||
import torch
|
||||
|
||||
from . import conds
|
||||
@ -11,15 +30,16 @@ from . import ops
|
||||
from . import utils
|
||||
from .ldm.audio.dit import AudioDiffusionTransformer
|
||||
from .ldm.audio.embedders import NumberConditioner
|
||||
from .ldm.aura.mmdit import MMDiT as AuraMMDiT
|
||||
from .ldm.cascade.stage_b import StageB
|
||||
from .ldm.cascade.stage_c import StageC
|
||||
from .ldm.flux import model as flux_model
|
||||
from .ldm.hydit.models import HunYuanDiT
|
||||
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from .ldm.aura.mmdit import MMDiT as AuraMMDiT
|
||||
from .ldm.hydit.models import HunYuanDiT
|
||||
from .ldm.flux import model as flux_model
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
@ -68,26 +88,33 @@ def model_sampling(model_config, model_type):
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
TModule = TypeVar('TModule', bound=torch.nn.Module)
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device: torch.device = None, unet_model: Type[TModule] = UNetModel):
|
||||
super().__init__()
|
||||
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||
self.device: torch.device = device
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if self.manual_cast_dtype is not None:
|
||||
operations = ops.manual_cast
|
||||
if model_config.custom_operations is None:
|
||||
if self.manual_cast_dtype is not None:
|
||||
operations = ops.manual_cast
|
||||
else:
|
||||
operations = ops.disable_weight_init
|
||||
else:
|
||||
operations = ops.disable_weight_init
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
if model_management.force_channels_last():
|
||||
# todo: ???
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
logging.debug("using channels last mode for diffusion model")
|
||||
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
||||
logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
@ -96,7 +123,7 @@ class BaseModel(torch.nn.Module):
|
||||
self.adm_channels = 0
|
||||
|
||||
self.concat_keys = ()
|
||||
logging.info("model_type {}".format(model_type.name))
|
||||
logging.debug("model_type {}".format(model_type.name))
|
||||
logging.debug("adm {}".format(self.adm_channels))
|
||||
self.memory_usage_factor = model_config.memory_usage_factor
|
||||
|
||||
@ -669,6 +696,7 @@ class StableAudio1(BaseModel):
|
||||
sd["{}{}".format(k, l)] = s[l]
|
||||
return sd
|
||||
|
||||
|
||||
class HunyuanDiT(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=HunYuanDiT)
|
||||
@ -701,6 +729,7 @@ class HunyuanDiT(BaseModel):
|
||||
out['image_meta_size'] = conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||
return out
|
||||
|
||||
|
||||
class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=flux_model.Flux)
|
||||
|
||||
@ -136,8 +136,8 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["depth"] = 19
|
||||
dit_config["depth_single_blocks"] = 38
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["qkv_bias"] = True
|
||||
@ -494,7 +494,12 @@ def model_config_from_diffusers_unet(state_dict):
|
||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
out_sd = {}
|
||||
|
||||
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
|
||||
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
hidden_size = state_dict["x_embedder.bias"].shape[0]
|
||||
sd_map = utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
sd_map = utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
@ -520,7 +525,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
old_weight = out_sd.get(t[0], None)
|
||||
if old_weight is None:
|
||||
old_weight = torch.empty_like(weight)
|
||||
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
|
||||
if old_weight.shape[offset[0]] < offset[1] + offset[2]:
|
||||
exp = list(weight.shape)
|
||||
exp[offset[0]] = offset[1] + offset[2]
|
||||
new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
|
||||
new[:old_weight.shape[0]] = old_weight
|
||||
old_weight = new
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
|
||||
@ -571,7 +571,7 @@ def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_i
|
||||
# fix path representation
|
||||
local_files = set(f.replace("\\", "/") for f in local_files)
|
||||
# remove .huggingface
|
||||
local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface"))
|
||||
local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
|
||||
# local_files.issubsetof(repo_files)
|
||||
if len(local_files) > 0 and local_files.issubset(repo_files):
|
||||
local_dirs_snapshots.append(str(local_path))
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@ -143,10 +160,10 @@ if torch.cuda.is_available() and hasattr(torch.version, "hip") and torch.version
|
||||
logging.info(f"Detected HIP device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
logging.debug("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
|
||||
try:
|
||||
logging.info("pytorch version: {}".format(torch.version.__version__))
|
||||
logging.debug("pytorch version: {}".format(torch.version.__version__))
|
||||
except:
|
||||
pass
|
||||
|
||||
@ -171,7 +188,7 @@ else:
|
||||
pass
|
||||
try:
|
||||
XFORMERS_VERSION = xformers.version.__version__
|
||||
logging.info("xformers version: {}".format(XFORMERS_VERSION))
|
||||
logging.debug("xformers version: {}".format(XFORMERS_VERSION))
|
||||
if XFORMERS_VERSION.startswith("0.0.18"):
|
||||
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
||||
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
|
||||
@ -263,12 +280,12 @@ if cpu_state != CPUState.GPU:
|
||||
if cpu_state == CPUState.MPS:
|
||||
vram_state = VRAMState.SHARED
|
||||
|
||||
logging.info(f"Set vram state to: {vram_state.name}")
|
||||
logging.debug(f"Set vram state to: {vram_state.name}")
|
||||
|
||||
DISABLE_SMART_MEMORY = args.disable_smart_memory
|
||||
|
||||
if DISABLE_SMART_MEMORY:
|
||||
logging.info("Disabling smart memory management")
|
||||
logging.debug("Disabling smart memory management")
|
||||
|
||||
|
||||
def get_torch_device_name(device):
|
||||
@ -288,7 +305,7 @@ def get_torch_device_name(device):
|
||||
|
||||
|
||||
try:
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
logging.debug("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
|
||||
@ -315,9 +332,12 @@ class LoadedModel:
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def model_offloaded_memory(self):
|
||||
return self.model.model_size() - self.model.loaded_size()
|
||||
|
||||
def model_memory_required(self, device):
|
||||
if device == self.model.current_device:
|
||||
return 0
|
||||
if device == self.model.current_loaded_device():
|
||||
return self.model_offloaded_memory()
|
||||
else:
|
||||
return self.model_memory()
|
||||
|
||||
@ -329,15 +349,21 @@ class LoadedModel:
|
||||
|
||||
load_weights = not self.weights_loaded
|
||||
|
||||
try:
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
else:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
if self.model.loaded_size() > 0:
|
||||
use_more_vram = lowvram_model_memory
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram)
|
||||
else:
|
||||
try:
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
else:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
|
||||
@ -346,15 +372,24 @@ class LoadedModel:
|
||||
return self.real_model
|
||||
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
if force_patch_weights and self.model.lowvram_patch_counter > 0:
|
||||
if force_patch_weights and self.model.lowvram_patch_counter() > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def model_unload(self, unpatch_weights=True):
|
||||
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
||||
if memory_to_free is not None:
|
||||
if memory_to_free < self.model.loaded_size():
|
||||
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
||||
self.real_model = None
|
||||
return True
|
||||
|
||||
def model_use_more_vram(self, extra_memory):
|
||||
return self.model.partially_load(self.device, extra_memory)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.model is other.model
|
||||
@ -366,39 +401,59 @@ class LoadedModel:
|
||||
return f"<LoadedModel>"
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
for m in loaded_models:
|
||||
if m.device == device:
|
||||
extra_memory -= m.model_use_more_vram(extra_memory)
|
||||
if extra_memory <= 0:
|
||||
break
|
||||
|
||||
|
||||
def offloaded_memory(loaded_models, device):
|
||||
offloaded_mem = 0
|
||||
for m in loaded_models:
|
||||
if m.device == device:
|
||||
offloaded_mem += m.model_offloaded_memory()
|
||||
return offloaded_mem
|
||||
|
||||
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 1.2
|
||||
|
||||
|
||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
|
||||
with model_management_lock:
|
||||
to_unload = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
return _unload_model_clones(model, unload_weights_only, force_unload)
|
||||
|
||||
if len(to_unload) == 0:
|
||||
return True
|
||||
|
||||
same_weights = 0
|
||||
for i in to_unload:
|
||||
if model.clone_has_same_weights(current_loaded_models[i].model):
|
||||
same_weights += 1
|
||||
def _unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
|
||||
to_unload = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
|
||||
if same_weights == len(to_unload):
|
||||
unload_weight = False
|
||||
else:
|
||||
unload_weight = True
|
||||
if len(to_unload) == 0:
|
||||
return True
|
||||
|
||||
if not force_unload:
|
||||
if unload_weights_only and unload_weight == False:
|
||||
return None
|
||||
same_weights = 0
|
||||
for i in to_unload:
|
||||
if model.clone_has_same_weights(current_loaded_models[i].model):
|
||||
same_weights += 1
|
||||
|
||||
for i in to_unload:
|
||||
logging.debug("unload clone {}{}".format(i, unload_weight))
|
||||
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
|
||||
if same_weights == len(to_unload):
|
||||
unload_weight = False
|
||||
else:
|
||||
unload_weight = True
|
||||
|
||||
return unload_weight
|
||||
if not force_unload:
|
||||
if unload_weights_only and unload_weight == False:
|
||||
return None
|
||||
|
||||
for i in to_unload:
|
||||
logging.debug("unload clone {} {}".format(i, unload_weight))
|
||||
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
|
||||
|
||||
return unload_weight
|
||||
|
||||
|
||||
@tracer.start_as_current_span("Free Memory")
|
||||
@ -406,126 +461,158 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
|
||||
span = get_current_span()
|
||||
span.set_attribute("memory_required", memory_required)
|
||||
with model_management_lock:
|
||||
unloaded_models: List[LoadedModel] = []
|
||||
can_unload = []
|
||||
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
if get_free_memory(device) > memory_required:
|
||||
break
|
||||
current_loaded_models[i].model_unload()
|
||||
unloaded_models.append(i)
|
||||
|
||||
for i in sorted(unloaded_models, reverse=True):
|
||||
current_loaded_models.pop(i)
|
||||
|
||||
if len(unloaded_models) > 0:
|
||||
soft_empty_cache()
|
||||
else:
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
||||
if mem_free_torch > mem_free_total * 0.25:
|
||||
soft_empty_cache()
|
||||
|
||||
unloaded_models = _free_memory(memory_required, device, keep_loaded)
|
||||
span.set_attribute("unloaded_models", list(map(str, unloaded_models)))
|
||||
return unloaded_models
|
||||
|
||||
|
||||
def _free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
memory_to_free = None
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
free_mem = get_free_memory(device)
|
||||
if free_mem > memory_required:
|
||||
break
|
||||
memory_to_free = memory_required - free_mem
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
if current_loaded_models[i].model_unload(memory_to_free):
|
||||
unloaded_model.append(i)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
else:
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
||||
if mem_free_torch > mem_free_total * 0.25:
|
||||
soft_empty_cache()
|
||||
return unloaded_models
|
||||
|
||||
|
||||
@tracer.start_as_current_span("Load Models GPU")
|
||||
def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None) -> None:
|
||||
global vram_state
|
||||
def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
|
||||
span = get_current_span()
|
||||
if memory_required != 0:
|
||||
span.set_attribute("memory_required", memory_required)
|
||||
with model_management_lock:
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required)
|
||||
if minimum_memory_required is None:
|
||||
minimum_memory_required = extra_mem
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required)
|
||||
_load_models_gpu(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
|
||||
to_load = list(map(str, models))
|
||||
span.set_attribute("models", to_load)
|
||||
logging.info(f"Loaded {to_load}")
|
||||
|
||||
models = set(models)
|
||||
models_to_load = []
|
||||
models_already_loaded = []
|
||||
for x in models:
|
||||
loaded_model = LoadedModel(x)
|
||||
loaded = None
|
||||
|
||||
try:
|
||||
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||
except ValueError:
|
||||
loaded_model_index = None
|
||||
def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
|
||||
global vram_state
|
||||
|
||||
if loaded_model_index is not None:
|
||||
loaded = current_loaded_models[loaded_model_index]
|
||||
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
|
||||
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
|
||||
loaded = None
|
||||
else:
|
||||
loaded.currently_used = True
|
||||
models_already_loaded.append(loaded)
|
||||
if loaded is None:
|
||||
models_to_load.append(loaded_model)
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
|
||||
if minimum_memory_required is None:
|
||||
minimum_memory_required = extra_mem
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
|
||||
|
||||
models = set(models)
|
||||
|
||||
models_to_load = []
|
||||
models_already_loaded = []
|
||||
models_freed = []
|
||||
for x in models:
|
||||
loaded_model = LoadedModel(x)
|
||||
loaded = None
|
||||
|
||||
models_freed: List[LoadedModel] = []
|
||||
try:
|
||||
if len(models_to_load) == 0:
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
models_freed += free_memory(extra_mem, d, models_already_loaded)
|
||||
return
|
||||
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||
except:
|
||||
loaded_model_index = None
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
if loaded_model_index is not None:
|
||||
loaded = current_loaded_models[loaded_model_index]
|
||||
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
|
||||
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
|
||||
loaded = None
|
||||
else:
|
||||
loaded.currently_used = True
|
||||
models_already_loaded.append(loaded)
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
# todo: where does 1.3 come from?
|
||||
models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
|
||||
if loaded is None:
|
||||
models_to_load.append(loaded_model)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
|
||||
if weights_unloaded is not None:
|
||||
loaded_model.weights_loaded = not weights_unloaded
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
model = loaded_model.model
|
||||
torch_dev = model.load_device
|
||||
if is_device_cpu(torch_dev):
|
||||
vram_set_state = VRAMState.DISABLED
|
||||
if len(models_to_load) == 0:
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
|
||||
free_mem = get_free_memory(d)
|
||||
if free_mem < minimum_memory_required:
|
||||
models_to_load = free_memory(minimum_memory_required, d)
|
||||
models_freed += models_to_load
|
||||
else:
|
||||
vram_set_state = vram_state
|
||||
lowvram_model_memory = 0
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = loaded_model.model_memory_required(torch_dev)
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
|
||||
if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
|
||||
lowvram_model_memory = 0
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 64 * 1024 * 1024
|
||||
|
||||
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||
if len(models_to_load) == 0:
|
||||
return
|
||||
finally:
|
||||
span.set_attribute("models", list(map(str, models)))
|
||||
span.set_attribute("models_to_load", list(map(str, models_to_load)))
|
||||
span.set_attribute("models_freed", list(map(str, models_freed)))
|
||||
logging.info(f"Requested to load {','.join(map(str, models))}, models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) # unload clones where the weights are different
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for loaded_model in models_already_loaded:
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
|
||||
if weights_unloaded is not None:
|
||||
loaded_model.weights_loaded = not weights_unloaded
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
models_freed += free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
model = loaded_model.model
|
||||
torch_dev = model.load_device
|
||||
if is_device_cpu(torch_dev):
|
||||
vram_set_state = VRAMState.DISABLED
|
||||
else:
|
||||
vram_set_state = vram_state
|
||||
lowvram_model_memory = 0
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
|
||||
model_size = loaded_model.model_memory_required(torch_dev)
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
|
||||
if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
|
||||
lowvram_model_memory = 0
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 64 * 1024 * 1024
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_mem = get_free_memory(d)
|
||||
if free_mem > minimum_memory_required:
|
||||
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||
|
||||
span = get_current_span()
|
||||
span.set_attribute("models_to_load", list(map(str, models_to_load)))
|
||||
span.set_attribute("models_freed", list(map(str, models_freed)))
|
||||
|
||||
|
||||
@_deprecate_method(message="Use load_models_gpu instead", version="0.0.2")
|
||||
@ -605,6 +692,7 @@ def unet_initial_load_device(parameters, dtype):
|
||||
def maximum_vram_for_weights(device=None):
|
||||
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
|
||||
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
@ -629,12 +717,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
|
||||
if model_params * 2 > free_model_memory:
|
||||
return fp8_dtype
|
||||
|
||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
|
||||
return torch.float32
|
||||
|
||||
|
||||
@ -651,13 +749,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.flo
|
||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
if fp16_supported and torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and fp16_supported:
|
||||
return torch.float16
|
||||
if dt == torch.bfloat16 and bf16_supported:
|
||||
return torch.bfloat16
|
||||
|
||||
elif bf16_supported and torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float32
|
||||
return torch.float32
|
||||
|
||||
|
||||
def text_encoder_offload_device():
|
||||
@ -679,6 +777,21 @@ def text_encoder_device():
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
||||
return offload_device
|
||||
|
||||
if is_device_mps(load_device):
|
||||
return offload_device
|
||||
|
||||
mem_l = get_free_memory(load_device)
|
||||
mem_o = get_free_memory(offload_device)
|
||||
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
|
||||
return load_device
|
||||
else:
|
||||
return offload_device
|
||||
|
||||
|
||||
def text_encoder_dtype(device=None):
|
||||
if args.fp8_e4m3fn_text_enc:
|
||||
return torch.float8_e4m3fn
|
||||
|
||||
@ -1,8 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, Optional, Any
|
||||
import dataclasses
|
||||
from typing import Protocol, Optional, TypeVar, runtime_checkable
|
||||
|
||||
import torch
|
||||
import torch.nn
|
||||
|
||||
ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class DeviceSettable(Protocol):
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
...
|
||||
|
||||
@device.setter
|
||||
def device(self, value: torch.device):
|
||||
...
|
||||
|
||||
|
||||
class ModelManageable(Protocol):
|
||||
@ -22,32 +37,92 @@ class ModelManageable(Protocol):
|
||||
|
||||
@property
|
||||
def current_device(self) -> torch.device:
|
||||
...
|
||||
return next(self.model.parameters()).device
|
||||
|
||||
def is_clone(self, other: Any) -> bool:
|
||||
...
|
||||
def is_clone(self, other: ModelManageableT) -> bool:
|
||||
return other.model is self.model
|
||||
|
||||
def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
|
||||
...
|
||||
def clone_has_same_weights(self, clone: ModelManageableT) -> bool:
|
||||
return clone.model is self.model
|
||||
|
||||
def model_size(self) -> int:
|
||||
...
|
||||
from .model_management import module_size
|
||||
return module_size(self.model)
|
||||
|
||||
def model_patches_to(self, arg: torch.device | torch.dtype):
|
||||
...
|
||||
pass
|
||||
|
||||
def model_dtype(self) -> torch.dtype:
|
||||
...
|
||||
return next(self.model.parameters()).dtype
|
||||
|
||||
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
self.patch_model(device_to=device_to, patch_weights=False)
|
||||
return self.model
|
||||
|
||||
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
|
||||
"""
|
||||
Loads the model to the device
|
||||
:param device_to: the device to move the model weights to
|
||||
:param patch_weights: True if the patch's weights should also be moved
|
||||
:return:
|
||||
"""
|
||||
...
|
||||
|
||||
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
|
||||
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
"""
|
||||
Unloads the model by moving it to the offload device
|
||||
:param offload_device:
|
||||
:param unpatch_weights:
|
||||
:return:
|
||||
"""
|
||||
...
|
||||
|
||||
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
...
|
||||
def lowvram_patch_counter(self) -> int:
|
||||
return 0
|
||||
|
||||
def partially_load(self, device_to: torch.device, extra_memory=0) -> int:
|
||||
self.patch_model(device_to=device_to)
|
||||
return self.model_size()
|
||||
|
||||
def partially_unload(self, device_to: torch.device, extra_memory=0) -> int:
|
||||
self.unpatch_model(device_to)
|
||||
return self.model_size()
|
||||
|
||||
def memory_required(self, input_shape) -> int:
|
||||
from comfy.model_base import BaseModel
|
||||
|
||||
if isinstance(self.model, BaseModel):
|
||||
return self.model.memory_required(input_shape=input_shape)
|
||||
else:
|
||||
# todo: why isn't this true?
|
||||
return self.model_size()
|
||||
|
||||
def loaded_size(self) -> int:
|
||||
if self.current_loaded_device() == self.load_device:
|
||||
return self.model_size()
|
||||
return 0
|
||||
|
||||
def current_loaded_device(self) -> torch.device:
|
||||
return self.current_device
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MemoryMeasurements:
|
||||
model: torch.nn.Module | DeviceSettable
|
||||
model_loaded_weight_memory: int = 0
|
||||
lowvram_patch_counter: int = 0
|
||||
model_lowvram: bool = False
|
||||
_device: torch.device | None = None
|
||||
|
||||
@property
|
||||
def lowvram_patch_counter(self) -> int:
|
||||
...
|
||||
def device(self) -> torch.device:
|
||||
if isinstance(self.model, DeviceSettable):
|
||||
return self.model.device
|
||||
else:
|
||||
return self._device
|
||||
|
||||
@device.setter
|
||||
def device(self, value: torch.device):
|
||||
if isinstance(self.model, DeviceSettable):
|
||||
self.model.device = value
|
||||
self._device = value
|
||||
|
||||
@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
import collections
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
@ -5,10 +23,12 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn
|
||||
|
||||
from . import model_management
|
||||
from . import utils
|
||||
from .model_management_types import ModelManageable
|
||||
from .model_base import BaseModel
|
||||
from .model_management_types import ModelManageable, MemoryMeasurements
|
||||
from .types import UnetWrapperFunction
|
||||
|
||||
|
||||
@ -69,10 +89,27 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
return model_options
|
||||
|
||||
|
||||
def wipe_lowvram_weight(m):
|
||||
if hasattr(m, "prev_comfy_cast_weights"):
|
||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||
del m.prev_comfy_cast_weights
|
||||
m.weight_function = None
|
||||
m.bias_function = None
|
||||
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, model_patcher):
|
||||
self.key = key
|
||||
self.model_patcher = model_patcher
|
||||
|
||||
def __call__(self, weight):
|
||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||
|
||||
|
||||
class ModelPatcher(ModelManageable):
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False, ckpt_name: Optional[str] = None):
|
||||
def __init__(self, model: torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.model: torch.nn.Module = model
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.object_patches = {}
|
||||
@ -81,25 +118,21 @@ class ModelPatcher(ModelManageable):
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
self._current_device: torch.device
|
||||
if current_device is None:
|
||||
self._current_device = self.offload_device
|
||||
else:
|
||||
self._current_device = current_device
|
||||
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
self.model_lowvram = False
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
self.ckpt_name = ckpt_name
|
||||
self._lowvram_patch_counter = 0
|
||||
self._memory_measurements = MemoryMeasurements(self.model)
|
||||
|
||||
@property
|
||||
def lowvram_patch_counter(self):
|
||||
return self._lowvram_patch_counter
|
||||
def model_device(self) -> torch.device:
|
||||
return self._memory_measurements.device
|
||||
|
||||
@lowvram_patch_counter.setter
|
||||
def lowvram_patch_counter(self, value: int):
|
||||
self._lowvram_patch_counter = value
|
||||
@model_device.setter
|
||||
def model_device(self, value: torch.device):
|
||||
self._memory_measurements.device = value
|
||||
|
||||
def lowvram_patch_counter(self):
|
||||
return self._memory_measurements.lowvram_patch_counter
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
@ -107,8 +140,12 @@ class ModelPatcher(ModelManageable):
|
||||
self.size = model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def loaded_size(self):
|
||||
return self._memory_measurements.model_loaded_weight_memory
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
||||
n._memory_measurements = self._memory_measurements
|
||||
n.ckpt_name = self.ckpt_name
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
@ -122,11 +159,9 @@ class ModelPatcher(ModelManageable):
|
||||
return n
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
return hasattr(other, 'model') and self.model is other.model
|
||||
|
||||
def clone_has_same_weights(self, clone):
|
||||
def clone_has_same_weights(self, clone: "ModelPatcher"):
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
|
||||
@ -139,7 +174,8 @@ class ModelPatcher(ModelManageable):
|
||||
else:
|
||||
return True
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
def memory_required(self, input_shape) -> int:
|
||||
assert isinstance(self.model, BaseModel)
|
||||
return self.model.memory_required(input_shape=input_shape)
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||
@ -281,16 +317,16 @@ class ModelPatcher(ModelManageable):
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None):
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||
if key not in self.patches:
|
||||
return
|
||||
|
||||
weight = utils.get_attr(self.model, key)
|
||||
|
||||
inplace_update = self.weight_inplace_update
|
||||
inplace_update = self.weight_inplace_update or inplace_update
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
@ -319,31 +355,25 @@ class ModelPatcher(ModelManageable):
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self._current_device = device_to
|
||||
self.model_device = device_to
|
||||
self._memory_measurements.model_loaded_weight_memory = self.model_size()
|
||||
|
||||
return self.model
|
||||
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||
self.patch_model(device_to, patch_weights=False)
|
||||
|
||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, model_patcher):
|
||||
self.key = key
|
||||
self.model_patcher = model_patcher
|
||||
|
||||
def __call__(self, weight):
|
||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||
|
||||
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
for n, m in self.model.named_modules():
|
||||
lowvram_weight = False
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
module_mem = model_management.module_size(m)
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
if m.comfy_cast_weights:
|
||||
continue
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
@ -365,15 +395,39 @@ class ModelPatcher(ModelManageable):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
else:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
if m.comfy_cast_weights:
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if hasattr(m, "weight"):
|
||||
self.patch_weight_to_device(weight_key, device_to)
|
||||
self.patch_weight_to_device(bias_key, device_to)
|
||||
m.to(device_to)
|
||||
mem_counter += model_management.module_size(m)
|
||||
param = list(m.parameters())
|
||||
if len(param) > 0:
|
||||
weight = param[0]
|
||||
if weight.device == device_to:
|
||||
continue
|
||||
|
||||
weight_to = None
|
||||
if full_load: # TODO
|
||||
weight_to = device_to
|
||||
self.patch_weight_to_device(weight_key, device_to=weight_to) # TODO: speed this up without OOM
|
||||
self.patch_weight_to_device(bias_key, device_to=weight_to)
|
||||
m.to(device_to)
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
|
||||
self.model_lowvram = True
|
||||
self.lowvram_patch_counter = patch_counter
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
self._memory_measurements.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
||||
self._memory_measurements.model_lowvram = False
|
||||
self._memory_measurements.lowvram_patch_counter += patch_counter
|
||||
self._memory_measurements.model_loaded_weight_memory = mem_counter
|
||||
self.model_device = device_to
|
||||
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||
self.patch_model(device_to, patch_weights=False)
|
||||
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
@ -548,31 +602,28 @@ class ModelPatcher(ModelManageable):
|
||||
|
||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||
if unpatch_weights:
|
||||
if self.model_lowvram:
|
||||
if self._memory_measurements.model_lowvram:
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, "prev_comfy_cast_weights"):
|
||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||
del m.prev_comfy_cast_weights
|
||||
m.weight_function = None
|
||||
m.bias_function = None
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
self.model_lowvram = False
|
||||
self.lowvram_patch_counter = 0
|
||||
self._memory_measurements.model_lowvram = False
|
||||
self._memory_measurements.lowvram_patch_counter = 0
|
||||
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
if self.weight_inplace_update:
|
||||
for k in keys:
|
||||
utils.copy_to_param(self.model, k, self.backup[k])
|
||||
else:
|
||||
for k in keys:
|
||||
utils.set_attr_param(self.model, k, self.backup[k])
|
||||
for k in keys:
|
||||
bk = self.backup[k]
|
||||
if bk.inplace_update:
|
||||
utils.copy_to_param(self.model, k, bk.weight)
|
||||
else:
|
||||
utils.set_attr_param(self.model, k, bk.weight)
|
||||
|
||||
self.backup.clear()
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self._current_device = value = device_to
|
||||
self.model_device = device_to
|
||||
self._memory_measurements.model_loaded_weight_memory = 0
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
for k in keys:
|
||||
@ -580,9 +631,66 @@ class ModelPatcher(ModelManageable):
|
||||
|
||||
self.object_patches_backup.clear()
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
memory_freed = 0
|
||||
patch_counter = 0
|
||||
|
||||
for n, m in list(self.model.named_modules())[::-1]:
|
||||
if memory_to_free < memory_freed:
|
||||
break
|
||||
|
||||
shift_lowvram = False
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
module_mem = model_management.module_size(m)
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
if m.weight is not None and m.weight.device != device_to:
|
||||
for key in [weight_key, bias_key]:
|
||||
bk = self.backup.get(key, None)
|
||||
if bk is not None:
|
||||
if bk.inplace_update:
|
||||
utils.copy_to_param(self.model, key, bk.weight)
|
||||
else:
|
||||
utils.set_attr_param(self.model, key, bk.weight)
|
||||
self.backup.pop(key)
|
||||
|
||||
m.to(device_to)
|
||||
if weight_key in self.patches:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
patch_counter += 1
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
memory_freed += module_mem
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
self._memory_measurements.model_lowvram = True
|
||||
self._memory_measurements.lowvram_patch_counter += patch_counter
|
||||
self._memory_measurements.model_loaded_weight_memory -= memory_freed
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0):
|
||||
self.unpatch_model(unpatch_weights=False)
|
||||
self.patch_model(patch_weights=False)
|
||||
full_load = False
|
||||
if not self._memory_measurements.model_lowvram:
|
||||
return 0
|
||||
if self._memory_measurements.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||
full_load = True
|
||||
current_used = self._memory_measurements.model_loaded_weight_memory
|
||||
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
||||
return self._memory_measurements.model_loaded_weight_memory - current_used
|
||||
|
||||
def current_loaded_device(self):
|
||||
return self.model_device
|
||||
|
||||
@property
|
||||
def current_device(self) -> torch.device:
|
||||
return self._current_device
|
||||
return self.current_loaded_device()
|
||||
|
||||
def __str__(self):
|
||||
if self.ckpt_name is not None:
|
||||
|
||||
@ -823,14 +823,14 @@ class UNETLoader:
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_unet(self, unet_name, weight_dtype):
|
||||
dtype = None
|
||||
model_options = {}
|
||||
if weight_dtype == "fp8_e4m3fn":
|
||||
dtype = torch.float8_e4m3fn
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
dtype = torch.float8_e5m2
|
||||
model_options["dtype"] = torch.float8_e5m2
|
||||
|
||||
unet_path = get_or_download("unet", unet_name, KNOWN_UNET_MODELS)
|
||||
model = sd.load_unet(unet_path, dtype=dtype)
|
||||
model = sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
return (model,)
|
||||
|
||||
class CLIPLoader:
|
||||
|
||||
@ -91,7 +91,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||
|
||||
if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings):
|
||||
for (duration, module_name, success, new_nodes) in sorted(timings):
|
||||
logging.info(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
|
||||
logging.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
|
||||
if raise_on_failure and len(exceptions) > 0:
|
||||
try:
|
||||
raise ExceptionGroup("Node import failed", exceptions)
|
||||
|
||||
83
comfy/sd.py
83
comfy/sd.py
@ -20,21 +20,19 @@ from . import model_sampling
|
||||
from . import sd1_clip
|
||||
from . import sdxl_clip
|
||||
from . import utils
|
||||
from .model_management import load_models_gpu
|
||||
|
||||
from .text_encoders import sd2_clip
|
||||
from .text_encoders import sd3_clip
|
||||
from .text_encoders import hydit
|
||||
from .text_encoders import sa_t5
|
||||
from .text_encoders import aura_t5
|
||||
from .text_encoders import flux
|
||||
|
||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .model_management import load_models_gpu
|
||||
from .t2i_adapter import adapter
|
||||
from .taesd import taesd
|
||||
from .text_encoders import aura_t5
|
||||
from .text_encoders import flux
|
||||
from .text_encoders import hydit
|
||||
from .text_encoders import sa_t5
|
||||
from .text_encoders import sd2_clip
|
||||
from .text_encoders import sd3_clip
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
|
||||
@ -68,7 +66,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target: CLIPTarget=None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None=None):
|
||||
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = dict()
|
||||
if no_init:
|
||||
@ -79,9 +77,9 @@ class CLIP:
|
||||
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
params['device'] = offload_device
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
params['dtype'] = dtype
|
||||
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
||||
if "textmodel_json_config" not in params and textmodel_json_config is not None:
|
||||
params['textmodel_json_config'] = textmodel_json_config
|
||||
|
||||
@ -90,11 +88,16 @@ class CLIP:
|
||||
for dt in self.cond_stage_model.dtypes:
|
||||
if not model_management.supports_cast(load_device, dt):
|
||||
load_device = offload_device
|
||||
if params['device'] != offload_device:
|
||||
self.cond_stage_model.to(offload_device)
|
||||
logging.warning("Had to shift TE back.")
|
||||
|
||||
self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
if params['device'] == load_device:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
|
||||
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
||||
|
||||
def clone(self):
|
||||
n = CLIP(no_init=True)
|
||||
@ -476,7 +479,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
clip_target.clip = sd3_clip.SD3ClipModel
|
||||
clip_target.tokenizer = sd3_clip.SD3Tokenizer
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config)
|
||||
parameters = 0
|
||||
for c in clip_data:
|
||||
parameters += utils.calculate_parameters(c)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
@ -523,15 +530,21 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
return (model, clip, vae)
|
||||
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
|
||||
sd = utils.load_torch_file(ckpt_path)
|
||||
sd_keys = sd.keys()
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, ckpt_path=ckpt_path)
|
||||
if out is None:
|
||||
raise RuntimeError("Could not detect model type of: {}".format(ckpt_path))
|
||||
return out
|
||||
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, ckpt_path: str | None = None):
|
||||
clip = None
|
||||
clipvision = None
|
||||
vae = None
|
||||
model = None
|
||||
_model_patcher = None
|
||||
clip_target = None
|
||||
inital_load_device = None
|
||||
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
@ -540,13 +553,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
return None
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if weight_dtype is not None:
|
||||
unet_weight_dtype.append(weight_dtype)
|
||||
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
unet_dtype = model_options.get("weight_dtype", None)
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
@ -570,7 +588,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
|
||||
parameters = utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
@ -589,14 +608,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
logging.debug("left over keys: {}".format(left_over))
|
||||
|
||||
if output_model:
|
||||
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
|
||||
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), ckpt_name=os.path.basename(ckpt_path))
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
load_models_gpu([_model_patcher])
|
||||
model_management.load_models_gpu([_model_patcher], force_full_load=True)
|
||||
|
||||
return (_model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet_state_dict(sd, dtype=None): # load unet in diffusers or regular format
|
||||
def load_diffusion_model_state_dict(sd, model_options: dict = None): # load unet in diffusers or regular format
|
||||
if model_options is None:
|
||||
model_options = {}
|
||||
dtype = model_options.get("dtype", None)
|
||||
|
||||
# Allow loading unets from checkpoint files
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
@ -638,6 +660,7 @@ def load_unet_state_dict(sd, dtype=None): # load unet in diffusers or regular f
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
@ -647,15 +670,27 @@ def load_unet_state_dict(sd, dtype=None): # load unet in diffusers or regular f
|
||||
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
def load_diffusion_model(unet_path, model_options: dict = None):
|
||||
if model_options is None:
|
||||
model_options = {}
|
||||
sd = utils.load_torch_file(unet_path)
|
||||
model = load_unet_state_dict(sd, dtype=dtype)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||
|
||||
|
||||
def load_unet_state_dict(sd, dtype=None):
|
||||
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
|
||||
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
|
||||
|
||||
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||
clip_sd = None
|
||||
load_models = [model]
|
||||
|
||||
@ -16,7 +16,7 @@ except ImportError:
|
||||
from typing import Tuple, Sequence, TypeVar, Callable
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMixin
|
||||
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from . import clip_model
|
||||
from . import model_management
|
||||
@ -66,7 +66,7 @@ class ClipTokenWeightEncoder:
|
||||
|
||||
output = []
|
||||
for k in range(0, sections):
|
||||
z = out[k:k+1]
|
||||
z = out[k:k + 1]
|
||||
if has_weights:
|
||||
z_empty = out[-1]
|
||||
for i in range(len(z)):
|
||||
@ -112,7 +112,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json", package=__package__)
|
||||
|
||||
|
||||
self.operations = ops.manual_cast
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
@ -389,6 +388,18 @@ def expand_directory_list(directories):
|
||||
return list(dirs)
|
||||
|
||||
|
||||
def bundled_embed(embed, prefix, suffix): # bundled embedding in lora format
|
||||
i = 0
|
||||
out_list = []
|
||||
for k in embed:
|
||||
if k.startswith(prefix) and k.endswith(suffix):
|
||||
out_list.append(embed[k])
|
||||
if len(out_list) == 0:
|
||||
return None
|
||||
|
||||
return torch.cat(out_list, dim=0)
|
||||
|
||||
|
||||
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
|
||||
if isinstance(embedding_directory, str):
|
||||
embedding_directory = [embedding_directory]
|
||||
@ -455,8 +466,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
elif embed_key is not None and embed_key in embed:
|
||||
embed_out = embed[embed_key]
|
||||
else:
|
||||
values = embed.values()
|
||||
embed_out = next(iter(values))
|
||||
embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
|
||||
if embed_out is None:
|
||||
embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
|
||||
if embed_out is None:
|
||||
values = embed.values()
|
||||
embed_out = next(iter(values))
|
||||
return embed_out
|
||||
|
||||
|
||||
@ -631,6 +646,7 @@ class SDTokenizer:
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
|
||||
SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
|
||||
|
||||
|
||||
@ -664,6 +680,7 @@ class SD1Tokenizer:
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
@ -640,9 +640,9 @@ class Flux(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
|
||||
memory_usage_factor = 2.6
|
||||
memory_usage_factor = 2.8
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from . import model_base
|
||||
from . import utils
|
||||
@ -30,6 +48,7 @@ class BASE:
|
||||
memory_usage_factor = 2.0
|
||||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
|
||||
@ -1,3 +1,20 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
@ -472,8 +489,33 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
|
||||
block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
|
||||
"attn.to_out.0.bias": "img_attn.proj.bias",
|
||||
k = "{}.attn.".format(prefix_from)
|
||||
qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
|
||||
key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
|
||||
block_map = {
|
||||
"attn.to_out.0.weight": "img_attn.proj.weight",
|
||||
"attn.to_out.0.bias": "img_attn.proj.bias",
|
||||
"norm1.linear.weight": "img_mod.lin.weight",
|
||||
"norm1.linear.bias": "img_mod.lin.bias",
|
||||
"norm1_context.linear.weight": "txt_mod.lin.weight",
|
||||
"norm1_context.linear.bias": "txt_mod.lin.bias",
|
||||
"attn.to_add_out.weight": "txt_attn.proj.weight",
|
||||
"attn.to_add_out.bias": "txt_attn.proj.bias",
|
||||
"ff.net.0.proj.weight": "img_mlp.0.weight",
|
||||
"ff.net.0.proj.bias": "img_mlp.0.bias",
|
||||
"ff.net.2.weight": "img_mlp.2.weight",
|
||||
"ff.net.2.bias": "img_mlp.2.bias",
|
||||
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
|
||||
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
|
||||
"ff_context.net.2.weight": "txt_mlp.2.weight",
|
||||
"ff_context.net.2.bias": "txt_mlp.2.bias",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
@ -489,15 +531,41 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
|
||||
key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))
|
||||
|
||||
block_map = {#TODO
|
||||
block_map = {
|
||||
"norm.linear.weight": "modulation.lin.weight",
|
||||
"norm.linear.bias": "modulation.lin.bias",
|
||||
"proj_out.weight": "linear2.weight",
|
||||
"proj_out.bias": "linear2.bias",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
|
||||
|
||||
MAP_BASIC = { #TODO
|
||||
MAP_BASIC = {
|
||||
("final_layer.linear.bias", "proj_out.bias"),
|
||||
("final_layer.linear.weight", "proj_out.weight"),
|
||||
("img_in.bias", "x_embedder.bias"),
|
||||
("img_in.weight", "x_embedder.weight"),
|
||||
("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
|
||||
("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
|
||||
("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
|
||||
("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
|
||||
("txt_in.bias", "context_embedder.bias"),
|
||||
("txt_in.weight", "context_embedder.weight"),
|
||||
("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
|
||||
("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
|
||||
("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
||||
("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
||||
("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
|
||||
("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
|
||||
("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
|
||||
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
|
||||
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
||||
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
||||
}
|
||||
|
||||
for k in MAP_BASIC:
|
||||
@ -859,6 +927,8 @@ def seed_for_block(seed):
|
||||
numpy_rng_state = np.random.get_state()
|
||||
if torch.cuda.is_available():
|
||||
cuda_rng_state = torch.cuda.get_rng_state_all()
|
||||
else:
|
||||
cuda_rng_state = None
|
||||
|
||||
# Set the new seed
|
||||
torch.manual_seed(seed)
|
||||
|
||||
@ -19,6 +19,7 @@ class CLIPTextEncodeHunyuanDiT:
|
||||
cond = output.pop("cond")
|
||||
return ([[cond, output]], )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
||||
}
|
||||
|
||||
@ -11,9 +11,10 @@ from typing import Any, Dict, Optional, List, Callable, Union
|
||||
import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
|
||||
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
|
||||
LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel
|
||||
LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel, AutoModelForCausalLM
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
|
||||
from comfy.language.language_types import ProcessorResult
|
||||
from comfy.language.transformers_model_management import TransformersManagedModel
|
||||
@ -28,9 +29,9 @@ _AUTO_CHAT_TEMPLATE = "default"
|
||||
try:
|
||||
from llava import model
|
||||
|
||||
logging.info("Additional LLaVA models are now supported")
|
||||
logging.debug("Additional LLaVA models are now supported")
|
||||
except ImportError as exc:
|
||||
logging.info(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
|
||||
logging.debug(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
|
||||
|
||||
# aka kwargs type
|
||||
_GENERATION_KWARGS_TYPE = Dict[str, Any]
|
||||
@ -129,7 +130,7 @@ class TransformersGenerationConfig(CustomNode):
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",)
|
||||
"model": ("MODEL", {})
|
||||
}
|
||||
}
|
||||
|
||||
@ -247,13 +248,22 @@ class TransformersLoader(CustomNode):
|
||||
**hub_kwargs
|
||||
}
|
||||
|
||||
try:
|
||||
model = AutoModel.from_pretrained(**from_pretrained_kwargs)
|
||||
except Exception as exc_info:
|
||||
# not yet supported by automodel
|
||||
model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
|
||||
# try:
|
||||
# import flash_attn
|
||||
# from_pretrained_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
# except ImportError:
|
||||
# logging.debug("install flash_attn for improved performance using language nodes")
|
||||
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
|
||||
|
||||
if config_dict["model_type"] == "llava_next":
|
||||
model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
|
||||
else:
|
||||
try:
|
||||
model = AutoModel.from_pretrained(**from_pretrained_kwargs)
|
||||
except Exception:
|
||||
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs)
|
||||
|
||||
try:
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs)
|
||||
@ -265,6 +275,10 @@ class TransformersLoader(CustomNode):
|
||||
processor = None
|
||||
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs)
|
||||
|
||||
if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
logging.debug("enabled xformers memory efficient attention")
|
||||
|
||||
model_managed = TransformersManagedModel(
|
||||
repo_id=ckpt_name,
|
||||
model=model,
|
||||
|
||||
@ -263,6 +263,7 @@ class CLIPSave:
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
metadata["format"] = "pt"
|
||||
metadata["prompt"] = prompt_info
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
|
||||
@ -161,7 +161,7 @@ class BooleanRequestParameter(CustomNode):
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_TYPES = ("BOOLEAN",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "api/openapi"
|
||||
|
||||
|
||||
@ -108,3 +108,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
||||
"ControlNetApplySD3": ControlNetApplySD3,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Sampling
|
||||
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
|
||||
}
|
||||
|
||||
@ -16,7 +16,7 @@ try:
|
||||
from spandrel import MAIN_REGISTRY
|
||||
|
||||
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
|
||||
logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
|
||||
logging.debug("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
|
||||
except:
|
||||
pass
|
||||
|
||||
@ -26,20 +26,16 @@ class UpscaleModelManageable(ModelManageable):
|
||||
self.ckpt_name = ckpt_name
|
||||
self.model_descriptor = model_descriptor
|
||||
self.model = model_descriptor.model
|
||||
self.load_device = model_management.unet_offload_device()
|
||||
self.load_device = model_management.get_torch_device()
|
||||
self.offload_device = model_management.unet_offload_device()
|
||||
self._current_device = self.offload_device
|
||||
self._lowvram_patch_counter = 0
|
||||
|
||||
# Private properties for image sizes and channels
|
||||
self._input_size = (1, 512, 512) # Default input size (batch, height, width)
|
||||
self._input_size = (1, 512, 512)
|
||||
self._input_channels = model_descriptor.input_channels
|
||||
self._output_channels = model_descriptor.output_channels
|
||||
self.tile = 512
|
||||
|
||||
@property
|
||||
def current_device(self) -> torch.device:
|
||||
return self._current_device
|
||||
return self.model_descriptor.device
|
||||
|
||||
@property
|
||||
def input_size(self) -> tuple[int, int, int]:
|
||||
@ -65,21 +61,14 @@ class UpscaleModelManageable(ModelManageable):
|
||||
def is_clone(self, other: Any) -> bool:
|
||||
return isinstance(other, UpscaleModelManageable) and self.model is other.model
|
||||
|
||||
def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
|
||||
def clone_has_same_weights(self, clone) -> bool:
|
||||
return self.is_clone(clone)
|
||||
|
||||
def model_size(self) -> int:
|
||||
# Calculate the size of the model parameters
|
||||
model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
|
||||
|
||||
# Get the byte size of the model's dtype
|
||||
dtype_size = torch.finfo(self.model_dtype()).bits // 8
|
||||
|
||||
# Calculate the memory required for input and output images
|
||||
input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
|
||||
output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size
|
||||
|
||||
# Add some extra memory for processing
|
||||
extra_memory = (input_size + output_size) * 2 # This is an estimate, adjust as needed
|
||||
|
||||
return model_params_size + input_size + output_size + extra_memory
|
||||
@ -95,30 +84,16 @@ class UpscaleModelManageable(ModelManageable):
|
||||
|
||||
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
self.model.to(device=device_to)
|
||||
self._current_device = device_to
|
||||
self._lowvram_patch_counter += 1
|
||||
return self.model
|
||||
|
||||
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
|
||||
if patch_weights:
|
||||
self.model.to(device=device_to)
|
||||
self._current_device = device_to
|
||||
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
|
||||
self.model.to(device=device_to)
|
||||
return self.model
|
||||
|
||||
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
if unpatch_weights:
|
||||
self.model.to(device=offload_device)
|
||||
self._current_device = offload_device
|
||||
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
self.model.to(device=offload_device)
|
||||
return self.model
|
||||
|
||||
@property
|
||||
def lowvram_patch_counter(self) -> int:
|
||||
return self._lowvram_patch_counter
|
||||
|
||||
@lowvram_patch_counter.setter
|
||||
def lowvram_patch_counter(self, value: int):
|
||||
self._lowvram_patch_counter = value
|
||||
|
||||
def __str__(self):
|
||||
if self.ckpt_name is not None:
|
||||
return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"
|
||||
|
||||
@ -46,7 +46,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_known_repos(tmp_path_factory):
|
||||
async def test_known_repos(tmp_path_factory):
|
||||
prev_hub_cache = os.getenv("HF_HUB_CACHE")
|
||||
os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
|
||||
|
||||
|
||||
@ -39,12 +39,17 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
|
||||
workflow = json.loads(workflow_file.read_text())
|
||||
|
||||
prompt = Prompt.validate(workflow)
|
||||
# todo: add all the models we want to test a bit more elegantly
|
||||
# todo: add all the models we want to test a bit m2ore elegantly
|
||||
outputs = await client.queue_prompt(prompt)
|
||||
|
||||
if any(v.class_type == "SaveImage" for v in prompt.values()):
|
||||
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
|
||||
assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None
|
||||
elif any(v.class_type == "SaveAudio" for v in prompt.values()):
|
||||
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
|
||||
assert outputs[save_image_node_id]["audio"][0]["filename"] is not None
|
||||
save_audio_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
|
||||
assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None
|
||||
elif any(v.class_type == "PreviewString" for v in prompt.values()):
|
||||
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "PreviewString")
|
||||
output_str = outputs[save_image_node_id]["string"][0]
|
||||
assert output_str is not None
|
||||
assert len(output_str) > 0
|
||||
|
||||
53
tests/inference/workflows/image-upscale-with-model-0.json
Normal file
53
tests/inference/workflows/image-upscale-with-model-0.json
Normal file
@ -0,0 +1,53 @@
|
||||
{
|
||||
"17": {
|
||||
"inputs": {
|
||||
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
|
||||
"name": "",
|
||||
"title": "",
|
||||
"description": "",
|
||||
"__required": true
|
||||
},
|
||||
"class_type": "ImageRequestParameter",
|
||||
"_meta": {
|
||||
"title": "ImageRequestParameter"
|
||||
}
|
||||
},
|
||||
"19": {
|
||||
"inputs": {
|
||||
"model_name": "RealESRGAN_x4plus.pth"
|
||||
},
|
||||
"class_type": "UpscaleModelLoader",
|
||||
"_meta": {
|
||||
"title": "Load Upscale Model"
|
||||
}
|
||||
},
|
||||
"20": {
|
||||
"inputs": {
|
||||
"upscale_model": [
|
||||
"19",
|
||||
0
|
||||
],
|
||||
"image": [
|
||||
"17",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ImageUpscaleWithModel",
|
||||
"_meta": {
|
||||
"title": "Upscale Image (using Model)"
|
||||
}
|
||||
},
|
||||
"21": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"20",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
91
tests/inference/workflows/llava-0.json
Normal file
91
tests/inference/workflows/llava-0.json
Normal file
@ -0,0 +1,91 @@
|
||||
{
|
||||
"1": {
|
||||
"inputs": {
|
||||
"ckpt_name": "llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
"subfolder": ""
|
||||
},
|
||||
"class_type": "TransformersLoader",
|
||||
"_meta": {
|
||||
"title": "TransformersLoader"
|
||||
}
|
||||
},
|
||||
"3": {
|
||||
"inputs": {
|
||||
"max_new_tokens": 512,
|
||||
"repetition_penalty": 0,
|
||||
"seed": 2013744903,
|
||||
"use_cache": true,
|
||||
"__tokens": "\n\nThis is a black and white sketch of a woman. The image is stylized and does not provide enough detail to identify the specific person being depicted. It appears to be a portrait with a focus on the facial features and the hair, which is styled in a way that suggests it might be from a historical or classical period. The style of the drawing is reminiscent of the works of artists who specialize in portraiture, such as those from the Renaissance or the 19th century. </s>",
|
||||
"model": [
|
||||
"1",
|
||||
0
|
||||
],
|
||||
"tokens": [
|
||||
"4",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "TransformersGenerate",
|
||||
"_meta": {
|
||||
"title": "TransformersGenerate"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"prompt": "Who is this?",
|
||||
"chat_template": "llava-v1.6-mistral-7b-hf",
|
||||
"model": [
|
||||
"1",
|
||||
0
|
||||
],
|
||||
"images": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "OneShotInstructTokenize",
|
||||
"_meta": {
|
||||
"title": "OneShotInstructTokenize"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"value": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"output": "\n\nThis is a black and white sketch of a woman. The image is stylized and does not provide enough detail to identify the specific person being depicted. It appears to be a portrait with a focus on the facial features and the hair, which is styled in a way that suggests it might be from a historical or classical period. The style of the drawing is reminiscent of the works of artists who specialize in portraiture, such as those from the Renaissance or the 19th century. "
|
||||
},
|
||||
"class_type": "PreviewString",
|
||||
"_meta": {
|
||||
"title": "PreviewString"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
|
||||
"name": "",
|
||||
"title": "",
|
||||
"description": "",
|
||||
"__required": true
|
||||
},
|
||||
"class_type": "ImageRequestParameter",
|
||||
"_meta": {
|
||||
"title": "ImageRequestParameter"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"upscale_method": "nearest-exact",
|
||||
"megapixels": 1,
|
||||
"image": [
|
||||
"6",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ImageScaleToTotalPixels",
|
||||
"_meta": {
|
||||
"title": "ImageScaleToTotalPixels"
|
||||
}
|
||||
}
|
||||
}
|
||||
60
tests/inference/workflows/phi-3-0.json
Normal file
60
tests/inference/workflows/phi-3-0.json
Normal file
@ -0,0 +1,60 @@
|
||||
{
|
||||
"1": {
|
||||
"inputs": {
|
||||
"ckpt_name": "microsoft/Phi-3-mini-4k-instruct",
|
||||
"subfolder": ""
|
||||
},
|
||||
"class_type": "TransformersLoader",
|
||||
"_meta": {
|
||||
"title": "TransformersLoader"
|
||||
}
|
||||
},
|
||||
"3": {
|
||||
"inputs": {
|
||||
"max_new_tokens": 512,
|
||||
"repetition_penalty": 0,
|
||||
"seed": 2514389986,
|
||||
"use_cache": true,
|
||||
"__tokens": "The question \"What comes after apple?\" can be interpreted in a few ways. If we're discussing the alphabetical sequence, the letter that comes after 'A' (for apple) is 'B'. If we're discussing a sequence of fruits, it could be any fruit that follows apple in a particular list or context. For example, in a list of fruits, banana might come after apple.<|end|>",
|
||||
"model": [
|
||||
"1",
|
||||
0
|
||||
],
|
||||
"tokens": [
|
||||
"4",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "TransformersGenerate",
|
||||
"_meta": {
|
||||
"title": "TransformersGenerate"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"prompt": "What comes after apple?",
|
||||
"chat_template": "phi-3",
|
||||
"model": [
|
||||
"1",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "OneShotInstructTokenize",
|
||||
"_meta": {
|
||||
"title": "OneShotInstructTokenize"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"value": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"output": "The question \"What comes after apple?\" can be interpreted in a few ways. If we're discussing the alphabetical sequence, the letter that comes after 'A' (for apple) is 'B'. If we're discussing a sequence of fruits, it could be any fruit that follows apple in a particular list or context. For example, in a list of fruits, banana might come after apple."
|
||||
},
|
||||
"class_type": "PreviewString",
|
||||
"_meta": {
|
||||
"title": "PreviewString"
|
||||
}
|
||||
}
|
||||
}
|
||||
223
tests/inference/workflows/sdxl-union-controlnet-0.json
Normal file
223
tests/inference/workflows/sdxl-union-controlnet-0.json
Normal file
@ -0,0 +1,223 @@
|
||||
{
|
||||
"1": {
|
||||
"inputs": {
|
||||
"ckpt_name": "sd_xl_base_1.0.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"2": {
|
||||
"inputs": {
|
||||
"strength": 0.5,
|
||||
"start_percent": 0,
|
||||
"end_percent": 1,
|
||||
"positive": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"control_net": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"image": [
|
||||
"9",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ControlNetApplyAdvanced",
|
||||
"_meta": {
|
||||
"title": "Apply ControlNet (Advanced)"
|
||||
}
|
||||
},
|
||||
"3": {
|
||||
"inputs": {
|
||||
"text": "a girl with blue hair",
|
||||
"clip": [
|
||||
"1",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"add_noise": true,
|
||||
"noise_seed": 969970429360105,
|
||||
"cfg": 8,
|
||||
"model": [
|
||||
"1",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"2",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"2",
|
||||
1
|
||||
],
|
||||
"sampler": [
|
||||
"13",
|
||||
0
|
||||
],
|
||||
"sigmas": [
|
||||
"11",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"12",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SamplerCustom",
|
||||
"_meta": {
|
||||
"title": "SamplerCustom"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "",
|
||||
"clip": [
|
||||
"1",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"type": "canny/lineart/anime_lineart/mlsd",
|
||||
"control_net": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SetUnionControlNetType",
|
||||
"_meta": {
|
||||
"title": "SetUnionControlNetType"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"control_net_name": "xinsir-controlnet-union-sdxl-1.0-promax.safetensors"
|
||||
},
|
||||
"class_type": "ControlNetLoader",
|
||||
"_meta": {
|
||||
"title": "Load ControlNet Model"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"low_threshold": 0.4,
|
||||
"high_threshold": 0.8,
|
||||
"image": [
|
||||
"18",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "Canny",
|
||||
"_meta": {
|
||||
"title": "Canny"
|
||||
}
|
||||
},
|
||||
"11": {
|
||||
"inputs": {
|
||||
"model_type": "SDXL",
|
||||
"steps": 25,
|
||||
"denoise": 1
|
||||
},
|
||||
"class_type": "AlignYourStepsScheduler",
|
||||
"_meta": {
|
||||
"title": "AlignYourStepsScheduler"
|
||||
}
|
||||
},
|
||||
"12": {
|
||||
"inputs": {
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"13": {
|
||||
"inputs": {
|
||||
"eta": 1,
|
||||
"s_noise": 1
|
||||
},
|
||||
"class_type": "SamplerEulerAncestral",
|
||||
"_meta": {
|
||||
"title": "SamplerEulerAncestral"
|
||||
}
|
||||
},
|
||||
"14": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"5",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"1",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"15": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"14",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
},
|
||||
"17": {
|
||||
"inputs": {
|
||||
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
|
||||
"name": "",
|
||||
"title": "",
|
||||
"description": "",
|
||||
"__required": true
|
||||
},
|
||||
"class_type": "ImageRequestParameter",
|
||||
"_meta": {
|
||||
"title": "ImageRequestParameter"
|
||||
}
|
||||
},
|
||||
"18": {
|
||||
"inputs": {
|
||||
"upscale_method": "lanczos",
|
||||
"megapixels": 1,
|
||||
"image": [
|
||||
"17",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ImageScaleToTotalPixels",
|
||||
"_meta": {
|
||||
"title": "ImageScaleToTotalPixels"
|
||||
}
|
||||
}
|
||||
}
|
||||
302
tests/inference/workflows/sdxl-union-controlnet-1.json
Normal file
302
tests/inference/workflows/sdxl-union-controlnet-1.json
Normal file
@ -0,0 +1,302 @@
|
||||
{
|
||||
"1": {
|
||||
"inputs": {
|
||||
"ckpt_name": "sd_xl_base_1.0.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"2": {
|
||||
"inputs": {
|
||||
"strength": 0.5,
|
||||
"start_percent": 0,
|
||||
"end_percent": 1,
|
||||
"positive": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"control_net": [
|
||||
"28",
|
||||
0
|
||||
],
|
||||
"image": [
|
||||
"9",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ControlNetApplyAdvanced",
|
||||
"_meta": {
|
||||
"title": "Apply ControlNet (Advanced)"
|
||||
}
|
||||
},
|
||||
"3": {
|
||||
"inputs": {
|
||||
"text": [
|
||||
"26",
|
||||
0
|
||||
],
|
||||
"clip": [
|
||||
"1",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"add_noise": [
|
||||
"23",
|
||||
0
|
||||
],
|
||||
"noise_seed": [
|
||||
"20",
|
||||
0
|
||||
],
|
||||
"cfg": [
|
||||
"19",
|
||||
0
|
||||
],
|
||||
"model": [
|
||||
"1",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"2",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"2",
|
||||
1
|
||||
],
|
||||
"sampler": [
|
||||
"24",
|
||||
0
|
||||
],
|
||||
"sigmas": [
|
||||
"11",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"12",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SamplerCustom",
|
||||
"_meta": {
|
||||
"title": "SamplerCustom"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "",
|
||||
"clip": [
|
||||
"1",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"low_threshold": 0.4,
|
||||
"high_threshold": 0.8,
|
||||
"image": [
|
||||
"18",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "Canny",
|
||||
"_meta": {
|
||||
"title": "Canny"
|
||||
}
|
||||
},
|
||||
"11": {
|
||||
"inputs": {
|
||||
"model_type": "SDXL",
|
||||
"steps": 25,
|
||||
"denoise": 1
|
||||
},
|
||||
"class_type": "AlignYourStepsScheduler",
|
||||
"_meta": {
|
||||
"title": "AlignYourStepsScheduler"
|
||||
}
|
||||
},
|
||||
"12": {
|
||||
"inputs": {
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"14": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"5",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"1",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"15": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"14",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
},
|
||||
"17": {
|
||||
"inputs": {
|
||||
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
|
||||
"name": "",
|
||||
"title": "",
|
||||
"description": "",
|
||||
"__required": true
|
||||
},
|
||||
"class_type": "ImageRequestParameter",
|
||||
"_meta": {
|
||||
"title": "ImageRequestParameter"
|
||||
}
|
||||
},
|
||||
"18": {
|
||||
"inputs": {
|
||||
"upscale_method": "lanczos",
|
||||
"megapixels": 1,
|
||||
"image": [
|
||||
"17",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ImageScaleToTotalPixels",
|
||||
"_meta": {
|
||||
"title": "ImageScaleToTotalPixels"
|
||||
}
|
||||
},
|
||||
"19": {
|
||||
"inputs": {
|
||||
"value": 8,
|
||||
"name": "cfg",
|
||||
"title": "",
|
||||
"description": "",
|
||||
"__required": true
|
||||
},
|
||||
"class_type": "FloatRequestParameter",
|
||||
"_meta": {
|
||||
"title": "FloatRequestParameter"
|
||||
}
|
||||
},
|
||||
"20": {
|
||||
"inputs": {
|
||||
"value": 0,
|
||||
"name": "seed",
|
||||
"title": "",
|
||||
"description": "",
|
||||
"__required": true
|
||||
},
|
||||
"class_type": "IntRequestParameter",
|
||||
"_meta": {
|
||||
"title": "IntRequestParameter"
|
||||
}
|
||||
},
|
||||
"23": {
|
||||
"inputs": {
|
||||
"value": true,
|
||||
"name": "add_noise",
|
||||
"title": "",
|
||||
"description": "",
|
||||
"__required": true
|
||||
},
|
||||
"class_type": "BooleanRequestParameter",
|
||||
"_meta": {
|
||||
"title": "BooleanRequestParameter"
|
||||
}
|
||||
},
|
||||
"24": {
|
||||
"inputs": {
|
||||
"sampler_name": [
|
||||
"25",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerSelect",
|
||||
"_meta": {
|
||||
"title": "KSamplerSelect"
|
||||
}
|
||||
},
|
||||
"25": {
|
||||
"inputs": {
|
||||
"value": "euler",
|
||||
"name": "sampler_name",
|
||||
"title": "",
|
||||
"description": "",
|
||||
"__required": true
|
||||
},
|
||||
"class_type": "StringEnumRequestParameter",
|
||||
"_meta": {
|
||||
"title": "StringEnumRequestParameter"
|
||||
}
|
||||
},
|
||||
"26": {
|
||||
"inputs": {
|
||||
"value": "a girl with blue hair",
|
||||
"name": "",
|
||||
"title": "",
|
||||
"description": "",
|
||||
"__required": true
|
||||
},
|
||||
"class_type": "StringRequestParameter",
|
||||
"_meta": {
|
||||
"title": "StringRequestParameter"
|
||||
}
|
||||
},
|
||||
"27": {
|
||||
"inputs": {
|
||||
"control_net_name": "xinsir-controlnet-union-sdxl-1.0-promax.safetensors"
|
||||
},
|
||||
"class_type": "ControlNetLoader",
|
||||
"_meta": {
|
||||
"title": "Load ControlNet Model"
|
||||
}
|
||||
},
|
||||
"28": {
|
||||
"inputs": {
|
||||
"type": "canny/lineart/anime_lineart/mlsd",
|
||||
"control_net": [
|
||||
"27",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SetUnionControlNetType",
|
||||
"_meta": {
|
||||
"title": "SetUnionControlNetType"
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user