Merge commit '39fb74c5bd13a1dccf4d7293a2f7a755d9f43cbd' of github.com:comfyanonymous/ComfyUI

- Improvements to tests
- Fixes model management
- Fixes issues with language nodes
doctorpangloss 2024-08-13 20:08:56 -07:00
commit 0549f35e85
46 changed files with 2304 additions and 425 deletions

View File

@ -0,0 +1,53 @@
# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Pull Request CI Workflow Runs
on:
  pull_request_target:
    types: [labeled]
jobs:
  pr-test-stable:
    if: ${{ github.event.label.name == 'Run-CI-Test' }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux, windows]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["stable"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
          - os: windows
            runner_label: [self-hosted, win]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}
          use_prior_commit: 'true'
  comment:
    if: ${{ github.event.label.name == 'Run-CI-Test' }}
    runs-on: ubuntu-latest
    permissions:
      pull-requests: write
    steps:
      - uses: actions/github-script@v6
        with:
          script: |
            github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
            })

95
.github/workflows/test-ci.yml vendored Normal file
View File

@ -0,0 +1,95 @@
# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Full Comfy CI Workflow Runs
on:
  push:
    branches:
      - master
    paths-ignore:
      - 'app/**'
      - 'input/**'
      - 'output/**'
      - 'notebooks/**'
      - 'script_examples/**'
      - '.github/**'
      - 'web/**'
  workflow_dispatch:
jobs:
  test-stable:
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux, windows]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["stable"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
          - os: windows
            runner_label: [self-hosted, win]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}
  test-win-nightly:
    strategy:
      fail-fast: true
      matrix:
        os: [windows]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["nightly"]
        include:
          - os: windows
            runner_label: [self-hosted, win]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}
  test-unix-nightly:
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux]
        python_version: ["3.11"]
        cuda_version: ["12.1"]
        torch_version: ["nightly"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}

3
.gitignore vendored
View File

@@ -176,4 +176,5 @@ cython_debug/
 /tests-ui/data/object_info.json
 /user/
 *.log
 web_custom_versions/
+.DS_Store

View File

@@ -14,9 +14,9 @@ from typing import TypedDict
 import requests
 from typing_extensions import NotRequired

-from comfy.cli_args import DEFAULT_VERSION_STRING
-from comfy.cmd.folder_paths import add_model_folder_path
-from comfy.component_model.files import get_package_as_path
+from ..cli_args import DEFAULT_VERSION_STRING
+from ..cmd.folder_paths import add_model_folder_path
+from ..component_model.files import get_package_as_path

 REQUEST_TIMEOUT = 10  # seconds

View File

@@ -10,7 +10,6 @@ import time
 import traceback
 import typing
 from os import PathLike
-from pathlib import PurePath
 from typing import List, Optional, Tuple

 import lazy_object_proxy
@@ -34,9 +33,11 @@ from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOption
 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
 nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)

-from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
-from comfy.graph_utils import is_link, GraphBuilder
-from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
+# order matters
+from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
+from ..graph_utils import is_link, GraphBuilder
+from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID


 class IsChangedCache:
@@ -446,6 +447,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
             "traceback": traceback.format_tb(tb),
             "current_inputs": input_data_formatted
         }
+
+        if isinstance(ex, model_management.OOM_EXCEPTION):
+            logging.error("Got an OOM, unloading all loaded models.")
+            model_management.unload_all_models()
+
         return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)

     executed.add(unique_id)

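A note on the OOM handling added to execution.py above: when a node raises the out-of-memory exception type exposed by model_management, the executor now unloads every loaded model before reporting the failure, so the next prompt starts from a freed VRAM pool instead of failing again immediately. A minimal sketch of the same pattern, assuming only the names visible in the diff (model_management.OOM_EXCEPTION and model_management.unload_all_models); the helper name run_node_safely is hypothetical:

    import logging

    from comfy import model_management


    def run_node_safely(fn, *args, **kwargs):
        """Run a node function; on OOM, free all loaded models before re-raising."""
        try:
            return fn(*args, **kwargs)
        except model_management.OOM_EXCEPTION:
            logging.error("Got an OOM, unloading all loaded models.")
            model_management.unload_all_models()
            raise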
View File

@@ -55,7 +55,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time
-            logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
+            logging.debug("Prompt executed in {:.2f} seconds".format(execution_time))

         flags = q.get_flags()
         free_memory = flags.get("free_memory", False)

View File

@@ -838,7 +838,7 @@ class PromptServer(ExecutorToClientProgress):
         self.port = port

         if verbose:
-            logging.info("Starting server\n")
+            logging.info("Starting server")
             logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
         if call_on_start is not None:
             call_on_start(address, port)

View File

@@ -1,4 +1,24 @@
+"""
+This file is part of ComfyUI.
+Copyright (C) 2024 Comfy
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
+from enum import Enum
 import math
 import os
 import logging
@@ -12,7 +32,8 @@ from . import latent_formats
 from .cldm import cldm, mmdit
 from .t2i_adapter import adapter
-from .ldm.cascade import controlnet
+from .ldm import hydit, flux
+from .ldm.cascade import controlnet as cascade_controlnet


 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -33,6 +54,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
     else:
         return torch.cat([tensor] * batched_number, dim=0)

+class StrengthType(Enum):
+    CONSTANT = 1
+    LINEAR_UP = 2
+
 class ControlBase:
     def __init__(self, device=None):
         self.cond_hint_original = None
@@ -51,6 +76,8 @@ class ControlBase:
             device = model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
+        self.extra_conds = []
+        self.strength_type = StrengthType.CONSTANT

     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
         self.cond_hint_original = cond_hint
@@ -93,6 +120,8 @@ class ControlBase:
         c.latent_format = self.latent_format
         c.extra_args = self.extra_args.copy()
         c.vae = self.vae
+        c.extra_conds = self.extra_conds.copy()
+        c.strength_type = self.strength_type

     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -113,7 +142,10 @@ class ControlBase:
                 if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
                     applied_to.add(x)
-                    x *= self.strength
+                    if self.strength_type == StrengthType.CONSTANT:
+                        x *= self.strength
+                    elif self.strength_type == StrengthType.LINEAR_UP:
+                        x *= (self.strength ** float(len(control_output) - i))

                 if x.dtype != output_dtype:
                     x = x.to(output_dtype)
@@ -142,7 +174,7 @@ class ControlBase:

 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, ckpt_name: str = None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, ckpt_name: str = None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
@@ -154,6 +186,8 @@ class ControlNet(ControlBase):
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype
         self.latent_format = latent_format
+        self.extra_conds += extra_conds
+        self.strength_type = strength_type

     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -191,13 +225,16 @@ class ControlNet(ControlBase):
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

         context = cond.get('crossattn_controlnet', cond['c_crossattn'])
-        y = cond.get('y', None)
-        if y is not None:
-            y = y.to(dtype)
+        extra = self.extra_args.copy()
+        for c in self.extra_conds:
+            temp = cond.get(c, None)
+            if temp is not None:
+                extra[c] = temp.to(dtype)
+
         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
-        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
         return self.control_merge(control, control_prev, output_dtype)

     def copy(self):
@@ -286,6 +323,7 @@ class ControlLora(ControlNet):
         ControlBase.__init__(self, device)
         self.control_weights = control_weights
         self.global_average_pooling = global_average_pooling
+        self.extra_conds += ["y"]

     def pre_run(self, model, percent_to_timestep_function):
         super().pre_run(model, percent_to_timestep_function)
@@ -338,12 +376,8 @@ class ControlLora(ControlNet):
     def inference_memory_requirements(self, dtype):
         return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

-def load_controlnet_mmdit(sd):
-    new_sd = model_detection.convert_diffusers_mmdit(sd, "")
-    model_config = model_detection.model_config_from_unet(new_sd, "", True)
-    num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
-    for k in sd:
-        new_sd[k] = sd[k]
+def controlnet_config(sd):
+    model_config = model_detection.model_config_from_unet(sd, "", True)

     supported_inference_dtypes = model_config.supported_inference_dtypes
@@ -356,14 +390,27 @@ def load_controlnet_mmdit(sd):
     else:
         operations = ops.disable_weight_init

-    control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
-    missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
+    return model_config, operations, load_device, unet_dtype, manual_cast_dtype
+
+def controlnet_load_state_dict(control_model, sd):
+    missing, unexpected = control_model.load_state_dict(sd, strict=False)

     if len(missing) > 0:
         logging.warning("missing controlnet keys: {}".format(missing))

     if len(unexpected) > 0:
         logging.debug("unexpected controlnet keys: {}".format(unexpected))
+    return control_model
+
+def load_controlnet_mmdit(sd):
+    new_sd = model_detection.convert_diffusers_mmdit(sd, "")
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
+    num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)

     latent_format = latent_formats.SD3()
     latent_format.shift_factor = 0 #SD3 controlnet weirdness
@@ -371,8 +418,30 @@ def load_controlnet_mmdit(sd):

     return control

+def load_controlnet_hunyuandit(controlnet_data):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
+
+    control_model = hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
+    control_model = controlnet_load_state_dict(control_model, controlnet_data)
+
+    latent_format = latent_formats.SDXL()
+    extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
+    return control
+
+def load_controlnet_flux_xlabs(sd):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
+    control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, sd)
+    extra_conds = ['y', 'guidance']
+    control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
+    if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
+        return load_controlnet_hunyuandit(controlnet_data)
+
     if "lora_controlnet" in controlnet_data:
         return ControlLora(controlnet_data)
@@ -430,7 +499,10 @@ def load_controlnet(ckpt_path, model=None):
             logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
     elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
-        return load_controlnet_mmdit(controlnet_data)
+        if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
+            return load_controlnet_flux_xlabs(controlnet_data)
+        else:
+            return load_controlnet_mmdit(controlnet_data)

     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
@@ -590,11 +662,11 @@ def load_t2i_adapter(t2i_data):
             xl = True
         model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
     elif "backbone.0.0.weight" in keys:
-        model_ad = controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
+        model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
         compression_ratio = 32
         upscale_algorithm = 'bilinear'
     elif "backbone.10.blocks.0.weight" in keys:
-        model_ad = controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
+        model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
         compression_ratio = 1
         upscale_algorithm = 'nearest-exact'
     else:

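A note on the StrengthType added to controlnet.py above: control_merge now scales the i-th of N control residuals by strength (CONSTANT) or by strength ** (N - i) (LINEAR_UP), so with LINEAR_UP the multiplier ramps up toward the last residual. A small sketch of just that branch, as a hypothetical standalone helper mirroring the diff:

    from enum import Enum

    import torch


    class StrengthType(Enum):
        CONSTANT = 1
        LINEAR_UP = 2


    def scale_control_outputs(control_output, strength, strength_type):
        # mirrors the branch added to ControlBase.control_merge
        scaled = []
        for i, x in enumerate(control_output):
            if strength_type == StrengthType.CONSTANT:
                scaled.append(x * strength)
            elif strength_type == StrengthType.LINEAR_UP:
                scaled.append(x * (strength ** float(len(control_output) - i)))
        return scaled


    # with strength 0.8 and 3 residuals, LINEAR_UP scales them by 0.512, 0.64, 0.8
    outs = [torch.ones(1) for _ in range(3)]
    print([t.item() for t in scale_control_outputs(outs, 0.8, StrengthType.LINEAR_UP)])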
View File

@@ -28,7 +28,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
     unet = None
     if unet_path is not None:
-        unet = sd.load_unet(unet_path)
+        unet = sd.load_diffusion_model(unet_path)

     clip = None
     textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"])

View File

@@ -8,7 +8,7 @@ from tqdm.auto import trange, tqdm

 from . import utils
 from . import deis
-import comfy.model_patcher
+from .. import model_patcher

 def append_zero(x):
     return torch.cat([x, x.new_zeros([1])])
@@ -1032,7 +1032,7 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
         return args["denoised"]

     model_options = extra_args.get("model_options", {}).copy()
-    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+    extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
@@ -1058,7 +1058,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
         return args["denoised"]

     model_options = extra_args.get("model_options", {}).copy()
-    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+    extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):

View File

@@ -77,19 +77,9 @@ class TransformersManagedModel(ModelManageable):
         return self.model.config.to_dict()

-    @property
     def lowvram_patch_counter(self):
         return 0

-    @lowvram_patch_counter.setter
-    def lowvram_patch_counter(self, value: int):
-        warnings.warn("Not supported")
-        pass
-
-    load_device: torch.device
-    offload_device: torch.device
-    model: PreTrainedModel
-
     @property
     def current_device(self) -> torch.device:
         return self.model.device
@@ -127,12 +117,10 @@ class TransformersManagedModel(ModelManageable):
         warnings.warn("Transformers models do not currently support adapters like LoRAs")
         return self.model.to(device=device_to)

-    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
-        warnings.warn("Transformers models do not currently support adapters like LoRAs")
+    def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
         return self.model.to(device=device_to)

-    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
-        warnings.warn("Transformers models do not currently support adapters like LoRAs")
+    def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
         return self.model.to(device=offload_device)

     def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
@@ -177,7 +165,7 @@ class TransformersManagedModel(ModelManageable):
             self.processor.to(device=self.load_device)

         assert "<image>" in prompt.lower(), "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
-        batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
+        batch_feature: BatchFeature = self.processor([prompt], images=images.unbind(), padding=True, return_tensors="pt")
         if hasattr(self.processor, "to"):
             self.processor.to(device=self.offload_device)
         assert "input_ids" in batch_feature

View File

@ -0,0 +1,104 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
class ControlNetFlux(Flux):
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
# controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples = ()
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples = block_res_samples + (img,)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
return {"output": (controlnet_block_res_samples * 10)[:19]}
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
hint = hint * 2.0 - 1.0
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)

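A note on the return value of ControlNetFlux.forward_orig above: (controlnet_block_res_samples * 10)[:19] repeats the tuple of per-block residuals and truncates it to 19 entries, so a controlnet that is shallower than Flux's 19 double blocks still supplies one residual per base block by cycling through its own outputs. A small sketch of the tuple arithmetic, with stand-in integers for the residual tensors and an assumed controlnet depth of 2:

    # e.g. a 2-block controlnet feeding Flux's 19 double blocks
    controlnet_block_res_samples = (0, 1)               # one residual per controlnet block
    per_block = (controlnet_block_res_samples * 10)[:19]
    print(per_block)       # (0, 1, 0, 1, ..., 0): cycles through the 2 residuals
    print(len(per_block))  # 19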
View File

@@ -2,12 +2,12 @@ import math
 from dataclasses import dataclass

 import torch
-from einops import rearrange
 from torch import Tensor, nn

 from .math import attention, rope
 from ... import ops

 class EmbedND(nn.Module):
     def __init__(self, dim: int, theta: int, axes_dim: list):
         super().__init__()
@@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     """
     t = time_factor * t
     half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        t.device
-    )
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)

     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -48,7 +46,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
         embedding = embedding.to(t)
     return embedding

-
 class MLPEmbedder(nn.Module):
     def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
         super().__init__()
@@ -94,14 +91,6 @@ class SelfAttention(nn.Module):
         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
         self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)

-    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
-        qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
-        q, k = self.norm(q, k, v)
-        x = attention(q, k, v, pe=pe)
-        x = self.proj(x)
-        return x
-

 @dataclass
 class ModulationOut:
@@ -163,22 +152,21 @@ class DoubleStreamBlock(nn.Module):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         # run actual attention
-        q = torch.cat((txt_q, img_q), dim=2)
-        k = torch.cat((txt_k, img_k), dim=2)
-        v = torch.cat((txt_v, img_v), dim=2)
-
-        attn = attention(q, k, v, pe=pe)
+        attn = attention(torch.cat((txt_q, img_q), dim=2),
+                         torch.cat((txt_k, img_k), dim=2),
+                         torch.cat((txt_v, img_v), dim=2), pe=pe)
+
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

         # calculate the img bloks
@@ -186,8 +174,12 @@ class DoubleStreamBlock(nn.Module):
         img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

         # calculate the txt bloks
-        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+
+        if txt.dtype == torch.float16:
+            txt = txt.clip(-65504, 65504)

         return img, txt
@@ -232,14 +224,17 @@ class SingleStreamBlock(nn.Module):
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)

         # compute attention
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        return x + mod.gate * output
+        x += mod.gate * output
+        if x.dtype == torch.float16:
+            x = x.clip(-65504, 65504)
+        return x

 class LastLayer(nn.Module):

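The qkv reshapes above drop the einops dependency: rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=heads) is replaced by an equivalent view(...).permute(2, 0, 3, 1, 4). A quick sketch verifying the two forms produce the same tensor (the shapes here are arbitrary test values):

    import torch
    from einops import rearrange

    B, L, H, D = 2, 5, 4, 8
    qkv = torch.randn(B, L, 3 * H * D)

    a = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=H)
    b = qkv.view(qkv.shape[0], qkv.shape[1], 3, H, -1).permute(2, 0, 3, 1, 4)
    assert torch.equal(a, b)  # same split into q, k, v with shape [3, B, H, L, D]

The added clip(-65504, 65504) on the residual streams keeps accumulated values inside the representable range of float16, whose largest finite value is 65504.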
View File

@@ -38,7 +38,7 @@ class Flux(nn.Module):
     Transformer model for flow matching on sequences.
     """

-    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
         self.dtype = dtype
         params = FluxParams(**kwargs)
@@ -83,7 +83,8 @@ class Flux(nn.Module):
             ]
         )

-        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+        if final_layer:
+            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)

     def forward_orig(
         self,
@@ -94,6 +95,7 @@ class Flux(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
+        control=None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -112,8 +114,15 @@ class Flux(nn.Module):
         ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)

-        for block in self.double_blocks:
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+        for i in range(len(self.double_blocks)):
+            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
+
+            if control is not None: #Controlnet
+                control_o = control.get("output")
+                if i < len(control_o):
+                    add = control_o[i]
+                    if add is not None:
+                        img += add

         img = torch.cat((txt, img), 1)
         for block in self.single_blocks:
@@ -123,7 +132,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img

-    def forward(self, x, timestep, context, y, guidance, **kwargs):
+    def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
         bs, c, h, w = x.shape
         patch_size = 2
         x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -138,5 +147,5 @@ class Flux(nn.Module):
         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
+        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

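The control hook added to Flux.forward_orig above adds the i-th controlnet residual to the image stream right after the i-th double block, skipping indices the controlnet did not produce and entries that are None. A standalone sketch of the same guarded addition, with hypothetical helper and argument names:

    def apply_double_block_control(img, i, control):
        # mirrors the loop body added to Flux.forward_orig: only indices covered by
        # the controlnet's "output" list contribute, and None entries are skipped
        if control is None:
            return img
        control_o = control.get("output")
        if i < len(control_o):
            add = control_o[i]
            if add is not None:
                img = img + add
        return img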
View File

@@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x

 def rotate_half(x):
-    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
     return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
@@ -78,10 +78,9 @@ def apply_rotary_emb(
     xk_out = None
     if isinstance(freqs_cis, tuple):
         cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
-        cos, sin = cos.to(xq.device), sin.to(xq.device)
-        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
+        xq_out = (xq * cos + rotate_half(xq) * sin)
         if xk is not None:
-            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
+            xk_out = (xk * cos + rotate_half(xk) * sin)
     else:
         xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
         freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]

View File

@ -0,0 +1,321 @@
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import checkpoint
from comfy.ldm.modules.diffusionmodules.mmdit import (
Mlp,
TimestepEmbedder,
PatchEmbed,
RMSNorm,
)
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool
import comfy.latent_formats
from .models import HunYuanDiTBlock, calc_rope
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
class HunYuanControlNet(nn.Module):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
input_size: tuple
The size of the input image.
patch_size: int
The size of the patch.
in_channels: int
The number of input channels.
hidden_size: int
The hidden size of the transformer backbone.
depth: int
The number of transformer blocks.
num_heads: int
The number of attention heads.
mlp_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
log_fn: callable
The logging function.
"""
def __init__(
self,
input_size: tuple = 128,
patch_size: int = 2,
in_channels: int = 4,
hidden_size: int = 1408,
depth: int = 40,
num_heads: int = 16,
mlp_ratio: float = 4.3637,
text_states_dim=1024,
text_states_dim_t5=2048,
text_len=77,
text_len_t5=256,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
size_cond=False,
use_style_cond=False,
learn_sigma=True,
norm="layer",
log_fn: callable = print,
attn_precision=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.log_fn = log_fn
self.depth = depth
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.text_states_dim = text_states_dim
self.text_states_dim_t5 = text_states_dim_t5
self.text_len = text_len
self.text_len_t5 = text_len_t5
self.size_cond = size_cond
self.use_style_cond = use_style_cond
self.norm = norm
self.dtype = dtype
self.latent_format = comfy.latent_formats.SDXL
self.mlp_t5 = nn.Sequential(
nn.Linear(
self.text_states_dim_t5,
self.text_states_dim_t5 * 4,
bias=True,
dtype=dtype,
device=device,
),
nn.SiLU(),
nn.Linear(
self.text_states_dim_t5 * 4,
self.text_states_dim,
bias=True,
dtype=dtype,
device=device,
),
)
# learnable replace
self.text_embedding_padding = nn.Parameter(
torch.randn(
self.text_len + self.text_len_t5,
self.text_states_dim,
dtype=dtype,
device=device,
)
)
# Attention pooling
pooler_out_dim = 1024
self.pooler = AttentionPool(
self.text_len_t5,
self.text_states_dim_t5,
num_heads=8,
output_dim=pooler_out_dim,
dtype=dtype,
device=device,
operations=operations,
)
# Dimension of the extra input vectors
self.extra_in_dim = pooler_out_dim
if self.size_cond:
# Image size and crop size conditions
self.extra_in_dim += 6 * 256
if self.use_style_cond:
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(
1, hidden_size, dtype=dtype, device=device
)
self.extra_in_dim += hidden_size
# Text embedding for `add`
self.x_embedder = PatchEmbed(
input_size,
patch_size,
in_channels,
hidden_size,
dtype=dtype,
device=device,
operations=operations,
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations
)
self.extra_embedder = nn.Sequential(
operations.Linear(
self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
),
nn.SiLU(),
operations.Linear(
hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
),
)
# Image embedding
num_patches = self.x_embedder.num_patches
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList(
[
HunYuanDiTBlock(
hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=self.text_states_dim,
qk_norm=qk_norm,
norm_type=self.norm,
skip=False,
attn_precision=attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
for _ in range(19)
]
)
# Input zero linear for the first block
self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
# Output zero linear for the every block
self.after_proj_list = nn.ModuleList(
[
operations.Linear(
self.hidden_size, self.hidden_size, dtype=dtype, device=device
)
for _ in range(len(self.blocks))
]
)
def forward(
self,
x,
hint,
timesteps,
context,#encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
return_dict=False,
**kwarg,
):
"""
Forward pass of the encoder.
Parameters
----------
x: torch.Tensor
(B, D, H, W)
t: torch.Tensor
(B)
encoder_hidden_states: torch.Tensor
CLIP text embedding, (B, L_clip, D)
text_embedding_mask: torch.Tensor
CLIP text embedding mask, (B, L_clip)
encoder_hidden_states_t5: torch.Tensor
T5 text embedding, (B, L_t5, D)
text_embedding_mask_t5: torch.Tensor
T5 text embedding mask, (B, L_t5)
image_meta_size: torch.Tensor
(B, 6)
style: torch.Tensor
(B)
cos_cis_img: torch.Tensor
sin_cis_img: torch.Tensor
return_dict: bool
Whether to return a dictionary.
"""
condition = hint
if condition.shape[0] == 1:
condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
text_states = context # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
text_states[:, -self.text_len :] = torch.where(
text_states_mask[:, -self.text_len :].unsqueeze(2),
text_states[:, -self.text_len :],
padding[: self.text_len],
)
text_states_t5[:, -self.text_len_t5 :] = torch.where(
text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
text_states_t5[:, -self.text_len_t5 :],
padding[self.text_len :],
)
text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,2051024
# _, _, oh, ow = x.shape
# th, tw = oh // self.patch_size, ow // self.patch_size
# Get image RoPE embedding according to `reso`lution.
freqs_cis_img = calc_rope(
x, self.patch_size, self.hidden_size // self.num_heads
) # (cos_cis_img, sin_cis_img)
# ========================= Build time and image embedding =========================
t = self.t_embedder(timesteps, dtype=self.dtype)
x = self.x_embedder(x)
# ========================= Concatenate all extra vectors =========================
# Build text tokens with pooling
extra_vec = self.pooler(encoder_hidden_states_t5)
# Build image meta size tokens if applicable
# if image_meta_size is not None:
# image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
# if image_meta_size.dtype != self.dtype:
# image_meta_size = image_meta_size.half()
# image_meta_size = image_meta_size.view(-1, 6 * 256)
# extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
# Build style tokens
if style is not None:
style_embedding = self.style_embedder(style)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
# ========================= Deal with Condition =========================
condition = self.x_embedder(condition)
# ========================= Forward pass through HunYuanDiT blocks =========================
controls = []
x = x + self.before_proj(condition) # add condition
for layer, block in enumerate(self.blocks):
x = block(x, c, text_states, freqs_cis_img)
controls.append(self.after_proj_list[layer](x)) # zero linear for output
return {"output": controls}

View File

@@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size):
     sub_args = [start, stop, (th, tw)]
     # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
     rope = get_2d_rotary_pos_embed(head_size, *sub_args)
+    rope = (rope[0].to(x), rope[1].to(x))
     return rope
@@ -91,6 +92,8 @@ class HunYuanDiTBlock(nn.Module):
         # Long Skip Connection
         if self.skip_linear is not None:
             cat = torch.cat([x, skip], dim=-1)
+            if cat.dtype != x.dtype:
+                cat = cat.to(x.dtype)
             cat = self.skip_norm(cat)
             x = self.skip_linear(cat)
@@ -362,6 +365,8 @@ class HunYuanDiT(nn.Module):
         c = t + self.extra_embedder(extra_vec)  # [B, D]

         controls = None
+        if control:
+            controls = control.get("output", None)
         # ========================= Forward pass through HunYuanDiT blocks =========================
         skips = []
         for layer, block in enumerate(self.blocks):

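The two sides of the Hunyuan DiT controlnet shown in this commit communicate through a plain dict: HunYuanControlNet.forward returns {"output": controls} with one residual per controlnet block, and HunYuanDiT reads it back with control.get("output", None) before its block loop. A tiny sketch of that contract, with fabricated shapes and a stubbed producer standing in for the controlnet:

    import torch

    # producer side: what the controlnet's forward returns (one residual per block)
    def fake_controlnet_forward(num_blocks=19):
        return {"output": [torch.zeros(2, 16, 1408) for _ in range(num_blocks)]}

    # consumer side: what the base model does with it before iterating its blocks
    control = fake_controlnet_forward()
    controls = control.get("output", None) if control else None
    print(len(controls))  # 19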
View File

@@ -411,17 +411,17 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha

 optimized_attention = attention_basic

 if model_management.xformers_enabled():
-    logging.info("Using xformers cross attention")
+    logging.debug("Using xformers cross attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
-    logging.info("Using pytorch cross attention")
+    logging.debug("Using pytorch cross attention")
     optimized_attention = attention_pytorch
 else:
     if args.use_split_cross_attention:
-        logging.info("Using split optimization for cross attention")
+        logging.debug("Using split optimization for cross attention")
         optimized_attention = attention_split
     else:
-        logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        logging.debug("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad

 optimized_attention_masked = optimized_attention

View File

@@ -268,13 +268,13 @@ class AttnBlock(nn.Module):
                                  padding=0)

         if model_management.xformers_enabled_vae():
-            logging.info("Using xformers attention in VAE")
+            logging.debug("Using xformers attention in VAE")
             self.optimized_attention = xformers_attention
         elif model_management.pytorch_attention_enabled():
-            logging.info("Using pytorch attention in VAE")
+            logging.debug("Using pytorch attention in VAE")
             self.optimized_attention = pytorch_attention
         else:
-            logging.info("Using split attention in VAE")
+            logging.debug("Using split attention in VAE")
             self.optimized_attention = normal_attention

     def forward(self, x):

View File

@@ -1,3 +1,21 @@
+"""
+This file is part of ComfyUI.
+Copyright (C) 2024 Comfy
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import logging
 from . import utils
 from . import model_base
@@ -216,11 +234,17 @@ def model_lora_keys_clip(model, key_map={}):
             lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
             key_map[lora_key] = k

-    for k in sdk: #OneTrainer SD3 lora
-        if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
-            l_key = k[len("t5xxl.transformer."):-len(".weight")]
-            lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
-            key_map[lora_key] = k
+    for k in sdk:
+        if k.endswith(".weight"):
+            if k.startswith("t5xxl.transformer."): #OneTrainer SD3 lora
+                l_key = k[len("t5xxl.transformer."):-len(".weight")]
+                lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
+                key_map[lora_key] = k
+            elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
+                l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
+                lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
+                key_map[lora_key] = k

     k = "clip_g.transformer.text_projection.weight"
     if k in sdk:
@@ -243,6 +267,7 @@ def model_lora_keys_unet(model, key_map={}):
             key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = k
             key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
+            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names

     diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
     for k in diffusers_keys:

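With the generic mapping added above, a UNet weight is now reachable under three LoRA key spellings: the lora_unet_ prefix, the lora_prior_unet_ prefix used for Cascade, and the raw dotted key with ".weight" stripped. A small sketch of the resulting map for a single key, using an illustrative weight name and a hypothetical helper mirroring the loop in model_lora_keys_unet:

    def unet_lora_keys(k):
        # k is a state-dict key such as "diffusion_model.input_blocks.1.0.emb_layers.1.weight"
        key_map = {}
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k  # cascade lora prefix
            key_map["{}".format(k[:-len(".weight")])] = k       # generic format, no renaming
        return key_map

    print(unet_lora_keys("diffusion_model.input_blocks.1.0.emb_layers.1.weight"))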
View File

@ -1,7 +1,26 @@
import logging """
from enum import Enum This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import logging
import math import math
from enum import Enum
from typing import TypeVar, Type
import torch import torch
 from . import conds

@@ -11,15 +30,16 @@ from . import ops
 from . import utils
 from .ldm.audio.dit import AudioDiffusionTransformer
 from .ldm.audio.embedders import NumberConditioner
+from .ldm.aura.mmdit import MMDiT as AuraMMDiT
 from .ldm.cascade.stage_b import StageB
 from .ldm.cascade.stage_c import StageC
+from .ldm.flux import model as flux_model
+from .ldm.hydit.models import HunYuanDiT
 from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
 from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
-from .ldm.aura.mmdit import MMDiT as AuraMMDiT
-from .ldm.hydit.models import HunYuanDiT
-from .ldm.flux import model as flux_model

 class ModelType(Enum):
     EPS = 1
@@ -68,26 +88,33 @@ def model_sampling(model_config, model_type):
     return ModelSampling(model_config)

+TModule = TypeVar('TModule', bound=torch.nn.Module)
+
 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device: torch.device = None, unet_model: Type[TModule] = UNetModel):
         super().__init__()

         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
         self.manual_cast_dtype = model_config.manual_cast_dtype
+        self.device: torch.device = device

         if not unet_config.get("disable_unet_model_creation", False):
-            if self.manual_cast_dtype is not None:
-                operations = ops.manual_cast
-            else:
-                operations = ops.disable_weight_init
+            if model_config.custom_operations is None:
+                if self.manual_cast_dtype is not None:
+                    operations = ops.manual_cast
+                else:
+                    operations = ops.disable_weight_init
+            else:
+                operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
             if model_management.force_channels_last():
                 # todo: ???
                 self.diffusion_model.to(memory_format=torch.channels_last)
                 logging.debug("using channels last mode for diffusion model")
-            logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
+            logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)

@@ -96,7 +123,7 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = 0

         self.concat_keys = ()
-        logging.info("model_type {}".format(model_type.name))
+        logging.debug("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor
@@ -669,6 +696,7 @@ class StableAudio1(BaseModel):
                 sd["{}{}".format(k, l)] = s[l]
         return sd

 class HunyuanDiT(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=HunYuanDiT)

@@ -701,6 +729,7 @@ class HunyuanDiT(BaseModel):
         out['image_meta_size'] = conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
         return out

 class Flux(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=flux_model.Flux)
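The custom_operations hook added to BaseModel.__init__ earlier in this file is the behavioral change worth calling out: a model config can now carry its own operations implementation, and the manual-cast fallback only applies when it does not. A rough standalone sketch of that decision (the names passed in are stand-ins, not the real ops classes):

def pick_operations(custom_operations, manual_cast_dtype, manual_cast_ops, default_ops):
    # An explicitly configured operations class wins outright.
    if custom_operations is not None:
        return custom_operations
    # Otherwise manual casting is used only when a cast dtype was chosen.
    if manual_cast_dtype is not None:
        return manual_cast_ops
    return default_ops

print(pick_operations(None, None, "manual_cast", "disable_weight_init"))  # -> disable_weight_init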

View File

@@ -136,8 +136,8 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["hidden_size"] = 3072
         dit_config["mlp_ratio"] = 4.0
         dit_config["num_heads"] = 24
-        dit_config["depth"] = 19
-        dit_config["depth_single_blocks"] = 38
+        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
+        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
         dit_config["axes_dim"] = [16, 56, 56]
         dit_config["theta"] = 10000
         dit_config["qkv_bias"] = True

@@ -494,7 +494,12 @@ def model_config_from_diffusers_unet(state_dict):

 def convert_diffusers_mmdit(state_dict, output_prefix=""):
     out_sd = {}
-    if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+    if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
+        depth = count_blocks(state_dict, 'transformer_blocks.{}.')
+        depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+        hidden_size = state_dict["x_embedder.bias"].shape[0]
+        sd_map = utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
+    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
         num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         sd_map = utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)

@@ -520,7 +525,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
                 old_weight = out_sd.get(t[0], None)
                 if old_weight is None:
                     old_weight = torch.empty_like(weight)
-                    old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
+                if old_weight.shape[offset[0]] < offset[1] + offset[2]:
+                    exp = list(weight.shape)
+                    exp[offset[0]] = offset[1] + offset[2]
+                    new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
+                    new[:old_weight.shape[0]] = old_weight
+                    old_weight = new

                 w = old_weight.narrow(offset[0], offset[1], offset[2])
             else:
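The Flux depth detection above now counts the numbered blocks present in the state dict instead of hardcoding 19/38. A minimal sketch of what a count_blocks-style helper does (an illustration, not the utils implementation):

def count_numbered_blocks(state_dict_keys, prefix_format):
    # prefix_format is something like "double_blocks.{}." ; count consecutive indices
    count = 0
    while any(k.startswith(prefix_format.format(count)) for k in state_dict_keys):
        count += 1
    return count

keys = ["double_blocks.0.attn.qkv.weight", "double_blocks.1.attn.qkv.weight"]
assert count_numbered_blocks(keys, "double_blocks.{}.") == 2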

View File

@@ -571,7 +571,7 @@ def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_i
                 # fix path representation
                 local_files = set(f.replace("\\", "/") for f in local_files)
                 # remove .huggingface
-                local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface"))
+                local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
                 # local_files.issubsetof(repo_files)
                 if len(local_files) > 0 and local_files.issubset(repo_files):
                     local_dirs_snapshots.append(str(local_path))

View File

@ -1,3 +1,20 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
 from __future__ import annotations

 import logging

@@ -143,10 +160,10 @@ if torch.cuda.is_available() and hasattr(torch.version, "hip") and torch.version
     logging.info(f"Detected HIP device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
-logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+logging.debug("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 try:
-    logging.info("pytorch version: {}".format(torch.version.__version__))
+    logging.debug("pytorch version: {}".format(torch.version.__version__))
 except:
     pass
@@ -171,7 +188,7 @@ else:
         pass
 try:
     XFORMERS_VERSION = xformers.version.__version__
-    logging.info("xformers version: {}".format(XFORMERS_VERSION))
+    logging.debug("xformers version: {}".format(XFORMERS_VERSION))
     if XFORMERS_VERSION.startswith("0.0.18"):
         logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
         logging.warning("Please downgrade or upgrade xformers to a different version.\n")

@@ -263,12 +280,12 @@ if cpu_state != CPUState.GPU:
     if cpu_state == CPUState.MPS:
         vram_state = VRAMState.SHARED

-logging.info(f"Set vram state to: {vram_state.name}")
+logging.debug(f"Set vram state to: {vram_state.name}")

 DISABLE_SMART_MEMORY = args.disable_smart_memory

 if DISABLE_SMART_MEMORY:
-    logging.info("Disabling smart memory management")
+    logging.debug("Disabling smart memory management")

 def get_torch_device_name(device):

@@ -288,7 +305,7 @@ def get_torch_device_name(device):
 try:
-    logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
+    logging.debug("Device: {}".format(get_torch_device_name(get_torch_device())))
 except:
     logging.warning("Could not pick default device.")
@@ -315,9 +332,12 @@ class LoadedModel:
     def model_memory(self):
         return self.model.model_size()

+    def model_offloaded_memory(self):
+        return self.model.model_size() - self.model.loaded_size()
+
     def model_memory_required(self, device):
-        if device == self.model.current_device:
-            return 0
+        if device == self.model.current_loaded_device():
+            return self.model_offloaded_memory()
         else:
             return self.model_memory()
@@ -329,15 +349,21 @@ class LoadedModel:
         load_weights = not self.weights_loaded

-        try:
-            if lowvram_model_memory > 0 and load_weights:
-                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
-            else:
-                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
-        except Exception as e:
-            self.model.unpatch_model(self.model.offload_device)
-            self.model_unload()
-            raise e
+        if self.model.loaded_size() > 0:
+            use_more_vram = lowvram_model_memory
+            if use_more_vram == 0:
+                use_more_vram = 1e32
+            self.model_use_more_vram(use_more_vram)
+        else:
+            try:
+                if lowvram_model_memory > 0 and load_weights:
+                    self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
+                else:
+                    self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
+            except Exception as e:
+                self.model.unpatch_model(self.model.offload_device)
+                self.model_unload()
+                raise e

         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
@@ -346,15 +372,24 @@ class LoadedModel:
         return self.real_model

     def should_reload_model(self, force_patch_weights=False):
-        if force_patch_weights and self.model.lowvram_patch_counter > 0:
+        if force_patch_weights and self.model.lowvram_patch_counter() > 0:
             return True
         return False

-    def model_unload(self, unpatch_weights=True):
+    def model_unload(self, memory_to_free=None, unpatch_weights=True):
+        if memory_to_free is not None:
+            if memory_to_free < self.model.loaded_size():
+                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                if freed >= memory_to_free:
+                    return False
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
         self.real_model = None
+        return True
+
+    def model_use_more_vram(self, extra_memory):
+        return self.model.partially_load(self.device, extra_memory)

     def __eq__(self, other):
         return self.model is other.model

@@ -366,39 +401,59 @@ class LoadedModel:
         return f"<LoadedModel>"
def use_more_memory(extra_memory, loaded_models, device):
for m in loaded_models:
if m.device == device:
extra_memory -= m.model_use_more_vram(extra_memory)
if extra_memory <= 0:
break
def offloaded_memory(loaded_models, device):
offloaded_mem = 0
for m in loaded_models:
if m.device == device:
offloaded_mem += m.model_offloaded_memory()
return offloaded_mem
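offloaded_memory and use_more_memory are the two halves of the new partial-loading accounting: one reports how many bytes of tracked models still sit on the offload device, the other hands spare VRAM back to them. A toy illustration of the bookkeeping they rely on (FakeModel and the numbers are invented for the example):

class FakeModel:
    def __init__(self, total, loaded):
        self.total, self.loaded = total, loaded
    def model_size(self):
        return self.total
    def loaded_size(self):
        return self.loaded

m = FakeModel(total=6 * 1024**3, loaded=4 * 1024**3)
# model_offloaded_memory() == model_size() - loaded_size() == 2 GiB still offloaded
print((m.model_size() - m.loaded_size()) / 1024**3)  # -> 2.0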
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 1.2

 def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
     with model_management_lock:
-        to_unload = []
-        for i in range(len(current_loaded_models)):
-            if model.is_clone(current_loaded_models[i].model):
-                to_unload = [i] + to_unload
-
-        if len(to_unload) == 0:
-            return True
-
-        same_weights = 0
-        for i in to_unload:
-            if model.clone_has_same_weights(current_loaded_models[i].model):
-                same_weights += 1
-
-        if same_weights == len(to_unload):
-            unload_weight = False
-        else:
-            unload_weight = True
-
-        if not force_unload:
-            if unload_weights_only and unload_weight == False:
-                return None
-
-        for i in to_unload:
-            logging.debug("unload clone {}{}".format(i, unload_weight))
-            current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
-
-        return unload_weight
+        return _unload_model_clones(model, unload_weights_only, force_unload)
+
+
+def _unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
+    to_unload = []
+    for i in range(len(current_loaded_models)):
+        if model.is_clone(current_loaded_models[i].model):
+            to_unload = [i] + to_unload
+
+    if len(to_unload) == 0:
+        return True
+
+    same_weights = 0
+    for i in to_unload:
+        if model.clone_has_same_weights(current_loaded_models[i].model):
+            same_weights += 1
+
+    if same_weights == len(to_unload):
+        unload_weight = False
+    else:
+        unload_weight = True
+
+    if not force_unload:
+        if unload_weights_only and unload_weight == False:
+            return None
+
+    for i in to_unload:
+        logging.debug("unload clone {} {}".format(i, unload_weight))
+        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
+
+    return unload_weight
 @tracer.start_as_current_span("Free Memory")

@@ -406,126 +461,158 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
     span = get_current_span()
     span.set_attribute("memory_required", memory_required)
     with model_management_lock:
-        unloaded_models: List[LoadedModel] = []
-        can_unload = []
-        for i in range(len(current_loaded_models) - 1, -1, -1):
-            shift_model = current_loaded_models[i]
-            if shift_model.device == device:
-                if shift_model not in keep_loaded:
-                    can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
-                    shift_model.currently_used = False
-
-        for x in sorted(can_unload):
-            i = x[-1]
-            if not DISABLE_SMART_MEMORY:
-                if get_free_memory(device) > memory_required:
-                    break
-            current_loaded_models[i].model_unload()
-            unloaded_models.append(i)
-
-        for i in sorted(unloaded_models, reverse=True):
-            current_loaded_models.pop(i)
-
-        if len(unloaded_models) > 0:
-            soft_empty_cache()
-        else:
-            if vram_state != VRAMState.HIGH_VRAM:
-                mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
-                if mem_free_torch > mem_free_total * 0.25:
-                    soft_empty_cache()
+        unloaded_models = _free_memory(memory_required, device, keep_loaded)
         span.set_attribute("unloaded_models", list(map(str, unloaded_models)))
         return unloaded_models
def _free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
unloaded_model = []
can_unload = []
unloaded_models = []
for i in range(len(current_loaded_models) - 1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
for x in sorted(can_unload):
i = x[-1]
memory_to_free = None
if not DISABLE_SMART_MEMORY:
free_mem = get_free_memory(device)
if free_mem > memory_required:
break
memory_to_free = memory_required - free_mem
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
if len(unloaded_model) > 0:
soft_empty_cache()
else:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
return unloaded_models
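The rewritten free path no longer unloads a model wholesale when only part of it needs to go: model_unload now receives memory_to_free and may satisfy it with a partial unload. A simplified sketch of that decision, ignoring the retry when the partial unload falls short (the function is illustrative only):

def plan_unload(loaded_size, memory_to_free):
    # Returns (bytes_to_offload, full_unload_needed) the way the new
    # model_unload(memory_to_free=...) path decides it.
    if memory_to_free is not None and memory_to_free < loaded_size:
        return memory_to_free, False   # partial unload is enough, model stays tracked
    return loaded_size, True           # fall back to a full unpatch/unload

print(plan_unload(loaded_size=8 * 1024**3, memory_to_free=2 * 1024**3))  # (2 GiB, False)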
@tracer.start_as_current_span("Load Models GPU") @tracer.start_as_current_span("Load Models GPU")
def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None) -> None: def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
global vram_state
span = get_current_span() span = get_current_span()
if memory_required != 0: if memory_required != 0:
span.set_attribute("memory_required", memory_required) span.set_attribute("memory_required", memory_required)
with model_management_lock: with model_management_lock:
inference_memory = minimum_inference_memory() _load_models_gpu(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
extra_mem = max(inference_memory, memory_required) to_load = list(map(str, models))
if minimum_memory_required is None: span.set_attribute("models", to_load)
minimum_memory_required = extra_mem logging.info(f"Loaded {to_load}")
else:
minimum_memory_required = max(inference_memory, minimum_memory_required)
models = set(models)
models_to_load = []
models_already_loaded = []
for x in models:
loaded_model = LoadedModel(x)
loaded = None
try: def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
loaded_model_index = current_loaded_models.index(loaded_model) global vram_state
except ValueError:
loaded_model_index = None
if loaded_model_index is not None: inference_memory = minimum_inference_memory()
loaded = current_loaded_models[loaded_model_index] extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic if minimum_memory_required is None:
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True) minimum_memory_required = extra_mem
loaded = None else:
else: minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
loaded.currently_used = True
models_already_loaded.append(loaded) models = set(models)
if loaded is None:
models_to_load.append(loaded_model) models_to_load = []
models_already_loaded = []
models_freed = []
for x in models:
loaded_model = LoadedModel(x)
loaded = None
models_freed: List[LoadedModel] = []
try: try:
if len(models_to_load) == 0: loaded_model_index = current_loaded_models.index(loaded_model)
devs = set(map(lambda a: a.device, models_already_loaded)) except:
for d in devs: loaded_model_index = None
if d != torch.device("cpu"):
models_freed += free_memory(extra_mem, d, models_already_loaded)
return
total_memory_required = {} if loaded_model_index is not None:
for loaded_model in models_to_load: loaded = current_loaded_models[loaded_model_index]
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
loaded.currently_used = True
models_already_loaded.append(loaded)
for device in total_memory_required: if loaded is None:
if device != torch.device("cpu"): models_to_load.append(loaded_model)
# todo: where does 1.3 come from?
models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load: if len(models_to_load) == 0:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded devs = set(map(lambda a: a.device, models_already_loaded))
if weights_unloaded is not None: for d in devs:
loaded_model.weights_loaded = not weights_unloaded if d != torch.device("cpu"):
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
for loaded_model in models_to_load: free_mem = get_free_memory(d)
model = loaded_model.model if free_mem < minimum_memory_required:
torch_dev = model.load_device models_to_load = free_memory(minimum_memory_required, d)
if is_device_cpu(torch_dev): models_freed += models_to_load
vram_set_state = VRAMState.DISABLED
else: else:
vram_set_state = vram_state use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
lowvram_model_memory = 0 if len(models_to_load) == 0:
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return return
finally:
span.set_attribute("models", list(map(str, models))) total_memory_required = {}
span.set_attribute("models_to_load", list(map(str, models_to_load))) for loaded_model in models_to_load:
span.set_attribute("models_freed", list(map(str, models_freed))) unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) # unload clones where the weights are different
logging.info(f"Requested to load {','.join(map(str, models))}, models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}") total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for loaded_model in models_already_loaded:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for device in total_memory_required:
if device != torch.device("cpu"):
models_freed += free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_mem = get_free_memory(d)
if free_mem > minimum_memory_required:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
span = get_current_span()
span.set_attribute("models_to_load", list(map(str, models_to_load)))
span.set_attribute("models_freed", list(map(str, models_freed)))
@_deprecate_method(message="Use load_models_gpu instead", version="0.0.2") @_deprecate_method(message="Use load_models_gpu instead", version="0.0.2")
@ -605,6 +692,7 @@ def unet_initial_load_device(parameters, dtype):
def maximum_vram_for_weights(device=None): def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory()) return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)): def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
if args.bf16_unet: if args.bf16_unet:
return torch.bfloat16 return torch.bfloat16
@@ -629,12 +717,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
     if model_params * 2 > free_model_memory:
         return fp8_dtype

-    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
-        if torch.float16 in supported_dtypes:
-            return torch.float16
-    if should_use_bf16(device, model_params=model_params, manual_cast=True):
-        if torch.bfloat16 in supported_dtypes:
-            return torch.bfloat16
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16
+
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16
+
     return torch.float32

@@ -651,13 +749,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
     if bf16_supported and weight_dtype == torch.bfloat16:
         return None

-    if fp16_supported and torch.float16 in supported_dtypes:
-        return torch.float16
-    elif bf16_supported and torch.bfloat16 in supported_dtypes:
-        return torch.bfloat16
-    else:
-        return torch.float32
+    for dt in supported_dtypes:
+        if dt == torch.float16 and fp16_supported:
+            return torch.float16
+        if dt == torch.bfloat16 and bf16_supported:
+            return torch.bfloat16
+
+    return torch.float32
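The practical effect of the new loops is that supported_dtypes now acts as an ordered preference list, tried first without manual casting and then with it. A self-contained sketch of the same selection pattern (the capability flags are stubbed for illustration):

import torch

def pick_weight_dtype(supported_dtypes, fp16_ok, bf16_ok):
    # First pass: dtypes the device can use, in the order the model prefers.
    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_ok:
            return torch.float16
        if dt == torch.bfloat16 and bf16_ok:
            return torch.bfloat16
    # Fall back to full precision when nothing in the list is usable.
    return torch.float32

print(pick_weight_dtype((torch.bfloat16, torch.float16), fp16_ok=True, bf16_ok=True))  # torch.bfloat16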
def text_encoder_offload_device(): def text_encoder_offload_device():
@ -679,6 +777,21 @@ def text_encoder_device():
return torch.device("cpu") return torch.device("cpu")
def text_encoder_initial_device(load_device, offload_device, model_size=0):
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
return offload_device
if is_device_mps(load_device):
return offload_device
mem_l = get_free_memory(load_device)
mem_o = get_free_memory(offload_device)
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
return load_device
else:
return offload_device
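text_encoder_initial_device is a heuristic: small text encoders (up to 1 GiB) and MPS systems stay on the offload device, and the load device is only chosen when it has meaningfully more free memory than the offload device and can hold the model with 20% headroom. A worked example of the comparison it performs (numbers invented):

def initial_device_example(model_size, mem_load, mem_offload):
    # Mirrors the shape of the heuristic above with concrete numbers.
    if model_size <= 1024 * 1024 * 1024:
        return "offload"
    if mem_load > (mem_offload * 0.5) and model_size * 1.2 < mem_load:
        return "load"
    return "offload"

# 5 GiB text encoder, 20 GiB free on the load device, 32 GiB free RAM -> goes straight to the load device
print(initial_device_example(5 * 1024**3, 20 * 1024**3, 32 * 1024**3))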
def text_encoder_dtype(device=None): def text_encoder_dtype(device=None):
if args.fp8_e4m3fn_text_enc: if args.fp8_e4m3fn_text_enc:
return torch.float8_e4m3fn return torch.float8_e4m3fn

View File

@@ -1,8 +1,23 @@
 from __future__ import annotations

-from typing import Protocol, Optional, Any
+import dataclasses
+from typing import Protocol, Optional, TypeVar, runtime_checkable

 import torch
+import torch.nn
ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
@runtime_checkable
class DeviceSettable(Protocol):
@property
def device(self) -> torch.device:
...
@device.setter
def device(self, value: torch.device):
...
 class ModelManageable(Protocol):

@@ -22,32 +37,92 @@ class ModelManageable(Protocol):
     @property
     def current_device(self) -> torch.device:
-        ...
+        return next(self.model.parameters()).device

-    def is_clone(self, other: Any) -> bool:
-        ...
+    def is_clone(self, other: ModelManageableT) -> bool:
+        return other.model is self.model

-    def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
-        ...
+    def clone_has_same_weights(self, clone: ModelManageableT) -> bool:
+        return clone.model is self.model

     def model_size(self) -> int:
-        ...
+        from .model_management import module_size
+        return module_size(self.model)

     def model_patches_to(self, arg: torch.device | torch.dtype):
-        ...
+        pass

     def model_dtype(self) -> torch.dtype:
-        ...
+        return next(self.model.parameters()).dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module: def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
self.patch_model(device_to=device_to, patch_weights=False)
return self.model
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
"""
Loads the model to the device
:param device_to: the device to move the model weights to
:param patch_weights: True if the patch's weights should also be moved
:return:
"""
         ...

-    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
-        ...
-
-    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
-        ...
+    def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        """
+        Unloads the model by moving it to the offload device
+        :param offload_device:
+        :param unpatch_weights:
+        :return:
+        """
+        ...
+
+    def lowvram_patch_counter(self) -> int:
+        return 0
def partially_load(self, device_to: torch.device, extra_memory=0) -> int:
self.patch_model(device_to=device_to)
return self.model_size()
def partially_unload(self, device_to: torch.device, extra_memory=0) -> int:
self.unpatch_model(device_to)
return self.model_size()
def memory_required(self, input_shape) -> int:
from comfy.model_base import BaseModel
if isinstance(self.model, BaseModel):
return self.model.memory_required(input_shape=input_shape)
else:
# todo: why isn't this true?
return self.model_size()
def loaded_size(self) -> int:
if self.current_loaded_device() == self.load_device:
return self.model_size()
return 0
def current_loaded_device(self) -> torch.device:
return self.current_device
@dataclasses.dataclass
class MemoryMeasurements:
model: torch.nn.Module | DeviceSettable
model_loaded_weight_memory: int = 0
lowvram_patch_counter: int = 0
model_lowvram: bool = False
_device: torch.device | None = None
     @property
-    def lowvram_patch_counter(self) -> int:
-        ...
+    def device(self) -> torch.device:
+        if isinstance(self.model, DeviceSettable):
return self.model.device
else:
return self._device
@device.setter
def device(self, value: torch.device):
if isinstance(self.model, DeviceSettable):
self.model.device = value
self._device = value
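With default method bodies, the Protocol now behaves more like a mixin: a wrapper only has to expose model, load_device and offload_device plus patch/unpatch, and inherits sizing, dtype and clone checks. A minimal hypothetical implementation (SimpleManaged is illustrative and not part of this diff):

import torch
import torch.nn

class SimpleManaged:
    # Only the pieces the protocol cannot default for us.
    def __init__(self, model: torch.nn.Module, load_device, offload_device):
        self.model = model
        self.load_device = load_device
        self.offload_device = offload_device

    def patch_model(self, device_to=None, patch_weights=True):
        if device_to is not None:
            self.model.to(device_to)
        return self.model

    def unpatch_model(self, offload_device=None, unpatch_weights=False):
        if offload_device is not None:
            self.model.to(offload_device)
        return self.model

wrapper = SimpleManaged(torch.nn.Linear(4, 4), torch.device("cpu"), torch.device("cpu"))
# The protocol defaults would report dtype/device straight from the wrapped module:
print(next(wrapper.model.parameters()).dtype, next(wrapper.model.parameters()).device)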

View File

@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import collections
 import copy
 import inspect
 import logging

@@ -5,10 +23,12 @@ import uuid
 from typing import Optional

 import torch
+import torch.nn

 from . import model_management
 from . import utils
-from .model_management_types import ModelManageable
+from .model_base import BaseModel
+from .model_management_types import ModelManageable, MemoryMeasurements
 from .types import UnetWrapperFunction

@@ -69,10 +89,27 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
return model_options return model_options
def wipe_lowvram_weight(m):
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
 class ModelPatcher(ModelManageable):
-    def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False, ckpt_name: Optional[str] = None):
+    def __init__(self, model: torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
-        self.model = model
+        self.model: torch.nn.Module = model
         self.patches = {}
         self.backup = {}
         self.object_patches = {}
@@ -81,25 +118,21 @@ class ModelPatcher(ModelManageable):
         self.model_size()
         self.load_device = load_device
         self.offload_device = offload_device
-        self._current_device: torch.device
-        if current_device is None:
-            self._current_device = self.offload_device
-        else:
-            self._current_device = current_device
         self.weight_inplace_update = weight_inplace_update
-        self.model_lowvram = False
         self.patches_uuid = uuid.uuid4()
         self.ckpt_name = ckpt_name
-        self._lowvram_patch_counter = 0
+        self._memory_measurements = MemoryMeasurements(self.model)

     @property
-    def lowvram_patch_counter(self):
-        return self._lowvram_patch_counter
+    def model_device(self) -> torch.device:
+        return self._memory_measurements.device

-    @lowvram_patch_counter.setter
-    def lowvram_patch_counter(self, value: int):
-        self._lowvram_patch_counter = value
+    @model_device.setter
+    def model_device(self, value: torch.device):
+        self._memory_measurements.device = value
+
+    def lowvram_patch_counter(self):
+        return self._memory_measurements.lowvram_patch_counter

     def model_size(self):
         if self.size > 0:
@@ -107,8 +140,12 @@ class ModelPatcher(ModelManageable):
         self.size = model_management.module_size(self.model)
         return self.size

+    def loaded_size(self):
+        return self._memory_measurements.model_loaded_weight_memory
+
     def clone(self):
-        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
+        n._memory_measurements = self._memory_measurements
         n.ckpt_name = self.ckpt_name
         n.patches = {}
         for k in self.patches:
@@ -122,11 +159,9 @@ class ModelPatcher(ModelManageable):
         return n

     def is_clone(self, other):
-        if hasattr(other, 'model') and self.model is other.model:
-            return True
-        return False
+        return hasattr(other, 'model') and self.model is other.model

-    def clone_has_same_weights(self, clone):
+    def clone_has_same_weights(self, clone: "ModelPatcher"):
         if not self.is_clone(clone):
             return False

@@ -139,7 +174,8 @@ class ModelPatcher(ModelManageable):
         else:
             return True

-    def memory_required(self, input_shape):
+    def memory_required(self, input_shape) -> int:
+        assert isinstance(self.model, BaseModel)
         return self.model.memory_required(input_shape=input_shape)

     def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
@@ -281,16 +317,16 @@ class ModelPatcher(ModelManageable):
             sd.pop(k)
         return sd

-    def patch_weight_to_device(self, key, device_to=None):
+    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
         if key not in self.patches:
             return

         weight = utils.get_attr(self.model, key)

-        inplace_update = self.weight_inplace_update
+        inplace_update = self.weight_inplace_update or inplace_update

         if key not in self.backup:
-            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+            self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

         if device_to is not None:
             temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
@@ -319,31 +355,25 @@ class ModelPatcher(ModelManageable):
         if device_to is not None:
             self.model.to(device_to)
-            self._current_device = device_to
+            self.model_device = device_to
+            self._memory_measurements.model_loaded_weight_memory = self.model_size()

         return self.model

-    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
-        self.patch_model(device_to, patch_weights=False)
-
-        logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
-        class LowVramPatch:
-            def __init__(self, key, model_patcher):
-                self.key = key
-                self.model_patcher = model_patcher
-            def __call__(self, weight):
-                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+    def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         mem_counter = 0
         patch_counter = 0
lowvram_counter = 0
for n, m in self.model.named_modules(): for n, m in self.model.named_modules():
lowvram_weight = False lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
if not full_load and hasattr(m, "comfy_cast_weights"):
module_mem = model_management.module_size(m) module_mem = model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory: if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True lowvram_weight = True
lowvram_counter += 1
if m.comfy_cast_weights:
continue
weight_key = "{}.weight".format(n) weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n) bias_key = "{}.bias".format(n)
@ -365,15 +395,39 @@ class ModelPatcher(ModelManageable):
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True m.comfy_cast_weights = True
else: else:
if hasattr(m, "comfy_cast_weights"):
if m.comfy_cast_weights:
wipe_lowvram_weight(m)
if hasattr(m, "weight"): if hasattr(m, "weight"):
self.patch_weight_to_device(weight_key, device_to)
self.patch_weight_to_device(bias_key, device_to)
m.to(device_to)
mem_counter += model_management.module_size(m) mem_counter += model_management.module_size(m)
param = list(m.parameters())
if len(param) > 0:
weight = param[0]
if weight.device == device_to:
continue
weight_to = None
if full_load: # TODO
weight_to = device_to
self.patch_weight_to_device(weight_key, device_to=weight_to) # TODO: speed this up without OOM
self.patch_weight_to_device(bias_key, device_to=weight_to)
m.to(device_to)
                 logging.debug("lowvram: loaded module regularly {} {}".format(n, m))

-        self.model_lowvram = True
-        self.lowvram_patch_counter = patch_counter
+        if lowvram_counter > 0:
+            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            self._memory_measurements.model_lowvram = True
+        else:
+            logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
+            self._memory_measurements.model_lowvram = False
+        self._memory_measurements.lowvram_patch_counter += patch_counter
+        self._memory_measurements.model_loaded_weight_memory = mem_counter
+        self.model_device = device_to
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
return self.model return self.model
     def calculate_weight(self, patches, weight, key):

@@ -548,31 +602,28 @@ class ModelPatcher(ModelManageable):
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         if unpatch_weights:
-            if self.model_lowvram:
+            if self._memory_measurements.model_lowvram:
                 for m in self.model.modules():
-                    if hasattr(m, "prev_comfy_cast_weights"):
-                        m.comfy_cast_weights = m.prev_comfy_cast_weights
-                        del m.prev_comfy_cast_weights
-                    m.weight_function = None
-                    m.bias_function = None
+                    wipe_lowvram_weight(m)

-                self.model_lowvram = False
-                self.lowvram_patch_counter = 0
+                self._memory_measurements.model_lowvram = False
+                self._memory_measurements.lowvram_patch_counter = 0

             keys = list(self.backup.keys())
-            if self.weight_inplace_update:
-                for k in keys:
-                    utils.copy_to_param(self.model, k, self.backup[k])
-            else:
-                for k in keys:
-                    utils.set_attr_param(self.model, k, self.backup[k])
+            for k in keys:
+                bk = self.backup[k]
+                if bk.inplace_update:
+                    utils.copy_to_param(self.model, k, bk.weight)
+                else:
+                    utils.set_attr_param(self.model, k, bk.weight)

             self.backup.clear()

             if device_to is not None:
                 self.model.to(device_to)
-                self._current_device = value = device_to
+                self.model_device = device_to
+                self._memory_measurements.model_loaded_weight_memory = 0

         keys = list(self.object_patches_backup.keys())
         for k in keys:

@@ -580,9 +631,66 @@ class ModelPatcher(ModelManageable):
self.object_patches_backup.clear() self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0
patch_counter = 0
for n, m in list(self.model.named_modules())[::-1]:
if memory_to_free < memory_freed:
break
shift_lowvram = False
if hasattr(m, "comfy_cast_weights"):
module_mem = model_management.module_size(m)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if m.weight is not None and m.weight.device != device_to:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
utils.copy_to_param(self.model, key, bk.weight)
else:
utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
memory_freed += module_mem
logging.debug("freed {}".format(n))
self._memory_measurements.model_lowvram = True
self._memory_measurements.lowvram_patch_counter += patch_counter
self._memory_measurements.model_loaded_weight_memory -= memory_freed
return memory_freed
def partially_load(self, device_to, extra_memory=0):
self.unpatch_model(unpatch_weights=False)
self.patch_model(patch_weights=False)
full_load = False
if not self._memory_measurements.model_lowvram:
return 0
if self._memory_measurements.model_loaded_weight_memory + extra_memory > self.model_size():
full_load = True
current_used = self._memory_measurements.model_loaded_weight_memory
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
return self._memory_measurements.model_loaded_weight_memory - current_used
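partially_load returns the number of bytes it actually managed to bring onto the device, which is what model_use_more_vram and use_more_memory subtract from their budgets. A toy walk-through of that accounting (values invented):

# Suppose 2 GiB of a 6 GiB model is resident and the scheduler grants 3 GiB extra.
loaded_before = 2 * 1024**3
extra_memory = 3 * 1024**3
model_size = 6 * 1024**3

budget = loaded_before + extra_memory            # lowvram_model_memory handed to the loader
full_load = budget > model_size                  # False here, so it stays a partial load
loaded_after = min(model_size, budget)           # 5 GiB resident after the call
print((loaded_after - loaded_before) / 1024**3)  # -> 3.0, the value partially_load reports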
def current_loaded_device(self):
return self.model_device
     @property
     def current_device(self) -> torch.device:
-        return self._current_device
+        return self.current_loaded_device()

     def __str__(self):
         if self.ckpt_name is not None:

View File

@@ -823,14 +823,14 @@ class UNETLoader:
     CATEGORY = "advanced/loaders"

     def load_unet(self, unet_name, weight_dtype):
-        dtype = None
+        model_options = {}
         if weight_dtype == "fp8_e4m3fn":
-            dtype = torch.float8_e4m3fn
+            model_options["dtype"] = torch.float8_e4m3fn
         elif weight_dtype == "fp8_e5m2":
-            dtype = torch.float8_e5m2
+            model_options["dtype"] = torch.float8_e5m2

         unet_path = get_or_download("unet", unet_name, KNOWN_UNET_MODELS)
-        model = sd.load_unet(unet_path, dtype=dtype)
+        model = sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)
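The node now forwards a model_options dict instead of a bare dtype, which is the extension point the new loader path reads. A hedged sketch of how a caller builds that dict (the file path is a placeholder, and the load call is shown commented since it needs a real checkpoint):

import torch
# from comfy import sd  # assuming the module layout used in this diff

model_options = {}
model_options["dtype"] = torch.float8_e4m3fn   # optional weight dtype override
# model_options["custom_operations"] = my_ops  # optional ops override, read further down in sd.py

# model = sd.load_diffusion_model("models/unet/some_model.safetensors", model_options=model_options)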
class CLIPLoader: class CLIPLoader:

View File

@@ -91,7 +91,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
     if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings):
         for (duration, module_name, success, new_nodes) in sorted(timings):
-            logging.info(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
+            logging.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")

     if raise_on_failure and len(exceptions) > 0:
         try:
             raise ExceptionGroup("Node import failed", exceptions)

View File

@@ -20,21 +20,19 @@ from . import model_sampling
 from . import sd1_clip
 from . import sdxl_clip
 from . import utils
-from .model_management import load_models_gpu
-from .text_encoders import sd2_clip
-from .text_encoders import sd3_clip
-from .text_encoders import hydit
-from .text_encoders import sa_t5
-from .text_encoders import aura_t5
-from .text_encoders import flux
 from .ldm.audio.autoencoder import AudioOobleckVAE
 from .ldm.cascade.stage_a import StageA
 from .ldm.cascade.stage_c_coder import StageC_coder
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
+from .model_management import load_models_gpu
 from .t2i_adapter import adapter
 from .taesd import taesd
+from .text_encoders import aura_t5
+from .text_encoders import flux
+from .text_encoders import hydit
+from .text_encoders import sa_t5
+from .text_encoders import sd2_clip
+from .text_encoders import sd3_clip

 def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
@@ -68,7 +66,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):

 class CLIP:
-    def __init__(self, target: CLIPTarget=None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None=None):
+    def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0):
         if tokenizer_data is None:
             tokenizer_data = dict()
         if no_init:

@@ -79,9 +77,9 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        params['device'] = offload_device
         dtype = model_management.text_encoder_dtype(load_device)
         params['dtype'] = dtype
+        params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))

         if "textmodel_json_config" not in params and textmodel_json_config is not None:
             params['textmodel_json_config'] = textmodel_json_config

@@ -90,11 +88,16 @@ class CLIP:
         for dt in self.cond_stage_model.dtypes:
             if not model_management.supports_cast(load_device, dt):
                 load_device = offload_device
+                if params['device'] != offload_device:
+                    self.cond_stage_model.to(offload_device)
+                    logging.warning("Had to shift TE back.")

         self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        if params['device'] == load_device:
+            model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
-        logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
+        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))

     def clone(self):
         n = CLIP(no_init=True)

@@ -476,7 +479,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
         clip_target.clip = sd3_clip.SD3ClipModel
         clip_target.tokenizer = sd3_clip.SD3Tokenizer

-    clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config)
+    parameters = 0
+    for c in clip_data:
+        parameters += utils.calculate_parameters(c)
+
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters)
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
@ -523,15 +530,21 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae) return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
sd = utils.load_torch_file(ckpt_path) sd = utils.load_torch_file(ckpt_path)
sd_keys = sd.keys() out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, ckpt_path=ckpt_path)
if out is None:
raise RuntimeError("Could not detect model type of: {}".format(ckpt_path))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, ckpt_path: str | None = None):
clip = None clip = None
clipvision = None clipvision = None
vae = None vae = None
model = None model = None
_model_patcher = None _model_patcher = None
clip_target = None inital_load_device = None
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = utils.calculate_parameters(sd, diffusion_model_prefix) parameters = utils.calculate_parameters(sd, diffusion_model_prefix)
@ -540,13 +553,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
    if model_config is None:
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+        return None

    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    if weight_dtype is not None:
        unet_weight_dtype.append(weight_dtype)

-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+    model_config.custom_operations = model_options.get("custom_operations", None)
+    unet_dtype = model_options.get("weight_dtype", None)
+
+    if unet_dtype is None:
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+
    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@ -570,7 +588,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
        if clip_target is not None:
            clip_sd = model_config.process_clip_state_dict(sd)
            if len(clip_sd) > 0:
-                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
+                parameters = utils.calculate_parameters(clip_sd)
+                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
                m, u = clip.load_sd(clip_sd, full_model=True)
                if len(m) > 0:
                    m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
@ -589,14 +608,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
            logging.debug("left over keys: {}".format(left_over))

    if output_model:
-        _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
+        _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), ckpt_name=os.path.basename(ckpt_path))
        if inital_load_device != torch.device("cpu"):
-            load_models_gpu([_model_patcher])
+            model_management.load_models_gpu([_model_patcher], force_full_load=True)

    return (_model_patcher, clip, vae, clipvision)

-def load_unet_state_dict(sd, dtype=None): # load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options: dict = None): # load unet in diffusers or regular format
+    if model_options is None:
+        model_options = {}
+    dtype = model_options.get("dtype", None)
+
    # Allow loading unets from checkpoint files
    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@ -638,6 +660,7 @@ def load_unet_state_dict(sd, dtype=None): # load unet in diffusers or regular f
    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
+    model_config.custom_operations = model_options.get("custom_operations", None)
    model = model_config.get_model(new_sd, "")
    model = model.to(offload_device)
    model.load_model_weights(new_sd, "")
@ -647,15 +670,27 @@ def load_unet_state_dict(sd, dtype=None): # load unet in diffusers or regular f
    return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)


-def load_unet(unet_path, dtype=None):
+def load_diffusion_model(unet_path, model_options: dict = None):
+    if model_options is None:
+        model_options = {}
    sd = utils.load_torch_file(unet_path)
-    model = load_unet_state_dict(sd, dtype=dtype)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options)
    if model is None:
        logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
        raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
    return model


+def load_unet(unet_path, dtype=None):
+    print("WARNING: the load_unet function has been deprecated and will be removed, please switch to: load_diffusion_model")
+    return load_diffusion_model(unet_path, model_options={"dtype": dtype})
+
+
+def load_unet_state_dict(sd, dtype=None):
+    print("WARNING: the load_unet_state_dict function has been deprecated and will be removed, please switch to: load_diffusion_model_state_dict")
+    return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
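
For callers outside this module, the rename is mechanical: the old dtype argument moves into a model_options dict, which can also carry custom_operations. A minimal sketch of migrating a call site, assuming a local safetensors path and that the module is importable this way (both illustrative):

import torch
from comfy import sd  # adjust the import to how this module is exposed in your setup

unet_path = "models/unet/model.safetensors"  # illustrative path

# Deprecated call (still works, but prints a warning):
# model = sd.load_unet(unet_path, dtype=torch.bfloat16)

# New call: dtype moves into model_options.
model = sd.load_diffusion_model(unet_path, model_options={
    "dtype": torch.bfloat16,
    # "custom_operations": my_ops,  # optional; forwarded to model_config.custom_operations
})
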

def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
    clip_sd = None
    load_models = [model]

View File

@ -16,7 +16,7 @@ except ImportError:
from typing import Tuple, Sequence, TypeVar, Callable

import torch
-from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMixin
+from transformers import CLIPTokenizer, PreTrainedTokenizerBase

from . import clip_model
from . import model_management
@ -66,7 +66,7 @@ class ClipTokenWeightEncoder:
        output = []
        for k in range(0, sections):
-            z = out[k:k+1]
+            z = out[k:k + 1]
            if has_weights:
                z_empty = out[-1]
                for i in range(len(z)):
@ -112,7 +112,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
        config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json", package=__package__)

        self.operations = ops.manual_cast
        self.transformer = model_class(config, dtype, device, self.operations)
        self.num_layers = self.transformer.num_layers
@ -389,6 +388,18 @@ def expand_directory_list(directories):
    return list(dirs)


+def bundled_embed(embed, prefix, suffix): # bundled embedding in lora format
+    i = 0
+    out_list = []
+    for k in embed:
+        if k.startswith(prefix) and k.endswith(suffix):
+            out_list.append(embed[k])
+    if len(out_list) == 0:
+        return None
+
+    return torch.cat(out_list, dim=0)
+
+
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]

@ -455,8 +466,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
    elif embed_key is not None and embed_key in embed:
        embed_out = embed[embed_key]
    else:
-        values = embed.values()
-        embed_out = next(iter(values))
+        embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
+        if embed_out is None:
+            embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
+        if embed_out is None:
+            values = embed.values()
+            embed_out = next(iter(values))
    return embed_out
@ -631,6 +646,7 @@ class SDTokenizer:
    def state_dict(self):
        return {}


SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")

@ -664,6 +680,7 @@ class SD1Tokenizer:
    def state_dict(self):
        return {}


class SD1ClipModel(torch.nn.Module):
    def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs):
        super().__init__()
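
The bundled_embed helper above lets load_embed recover textual-inversion weights that were packed into a LoRA-style file under bundle_emb.* keys before falling back to the first tensor in the file. A small self-contained sketch with made-up key names, purely to show what the lookup matches:

import torch

# Illustrative bundled file: embedding tensors live under "bundle_emb.<name>..." keys
# alongside ordinary LoRA keys, which the helper ignores.
embed = {
    "bundle_emb.myconcept.string_to_param.*": torch.randn(2, 768),
    "lora_unet_some_block.alpha": torch.tensor(8.0),
}

def bundled_embed(embed, prefix, suffix):
    out_list = [v for k, v in embed.items() if k.startswith(prefix) and k.endswith(suffix)]
    return torch.cat(out_list, dim=0) if len(out_list) > 0 else None

tokens = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
print(None if tokens is None else tuple(tokens.shape))  # (2, 768)
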

View File

@ -640,9 +640,9 @@ class Flux(supported_models_base.BASE):
    unet_extra_config = {}
    latent_format = latent_formats.Flux

-    memory_usage_factor = 2.6
+    memory_usage_factor = 2.8

-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

View File

@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch

from . import model_base
from . import utils

@ -30,6 +48,7 @@ class BASE:
    memory_usage_factor = 2.0

    manual_cast_dtype = None
+    custom_operations = None

    @classmethod
    def matches(s, unet_config, state_dict=None):

View File

@ -1,3 +1,20 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations

import contextlib
@ -472,8 +489,33 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
        key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
        key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

-        block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
-                     "attn.to_out.0.bias": "img_attn.proj.bias",
+        k = "{}.attn.".format(prefix_from)
+        qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
+        key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
+        key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
+        key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
+
+        block_map = {
+            "attn.to_out.0.weight": "img_attn.proj.weight",
+            "attn.to_out.0.bias": "img_attn.proj.bias",
+            "norm1.linear.weight": "img_mod.lin.weight",
+            "norm1.linear.bias": "img_mod.lin.bias",
+            "norm1_context.linear.weight": "txt_mod.lin.weight",
+            "norm1_context.linear.bias": "txt_mod.lin.bias",
+            "attn.to_add_out.weight": "txt_attn.proj.weight",
+            "attn.to_add_out.bias": "txt_attn.proj.bias",
+            "ff.net.0.proj.weight": "img_mlp.0.weight",
+            "ff.net.0.proj.bias": "img_mlp.0.bias",
+            "ff.net.2.weight": "img_mlp.2.weight",
+            "ff.net.2.bias": "img_mlp.2.bias",
+            "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
+            "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
+            "ff_context.net.2.weight": "txt_mlp.2.weight",
+            "ff_context.net.2.bias": "txt_mlp.2.bias",
+            "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
+            "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
+            "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
+            "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
+        }

        for k in block_map:
@ -489,15 +531,41 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
        key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
        key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
        key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
-        key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
+        key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))

-        block_map = {#TODO
+        block_map = {
+            "norm.linear.weight": "modulation.lin.weight",
+            "norm.linear.bias": "modulation.lin.bias",
+            "proj_out.weight": "linear2.weight",
+            "proj_out.bias": "linear2.bias",
+            "attn.norm_q.weight": "norm.query_norm.scale",
+            "attn.norm_k.weight": "norm.key_norm.scale",
+        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

-    MAP_BASIC = { #TODO
+    MAP_BASIC = {
+        ("final_layer.linear.bias", "proj_out.bias"),
+        ("final_layer.linear.weight", "proj_out.weight"),
+        ("img_in.bias", "x_embedder.bias"),
+        ("img_in.weight", "x_embedder.weight"),
+        ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
+        ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
+        ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
+        ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
+        ("txt_in.bias", "context_embedder.bias"),
+        ("txt_in.weight", "context_embedder.weight"),
+        ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
+        ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
+        ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
+        ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
+        ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
+        ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
+        ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
+        ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
+        ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
+        ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
+    }

    for k in MAP_BASIC:
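
Most entries above are plain renames, but the tuple-valued ones point a split Diffusers projection at a (dim, offset, size) slot inside the fused qkv tensor used here. A toy sketch of applying that convention, with made-up shapes (the real conversion in this codebase also handles prefixes, biases, and the swap_scale_shift entries):

import torch

hidden_size = 8  # toy size purely for illustration

# Entries shaped like the flux_to_diffusers map:
# diffusers key -> (fused key, (dim, offset, size))
key_map = {
    "attn.to_q.weight": ("img_attn.qkv.weight", (0, 0, hidden_size)),
    "attn.to_k.weight": ("img_attn.qkv.weight", (0, hidden_size, hidden_size)),
    "attn.to_v.weight": ("img_attn.qkv.weight", (0, hidden_size * 2, hidden_size)),
}

diffusers_sd = {k: torch.randn(hidden_size, hidden_size) for k in key_map}

# Write each split projection into its slot of the fused qkv weight.
fused = {"img_attn.qkv.weight": torch.zeros(hidden_size * 3, hidden_size)}
for src, (dst, (dim, offset, size)) in key_map.items():
    fused[dst].narrow(dim, offset, size).copy_(diffusers_sd[src])

print(fused["img_attn.qkv.weight"].shape)  # torch.Size([24, 8])
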
@ -859,6 +927,8 @@ def seed_for_block(seed):
    numpy_rng_state = np.random.get_state()
    if torch.cuda.is_available():
        cuda_rng_state = torch.cuda.get_rng_state_all()
+    else:
+        cuda_rng_state = None

    # Set the new seed
    torch.manual_seed(seed)

View File

@ -19,6 +19,7 @@ class CLIPTextEncodeHunyuanDiT:
        cond = output.pop("cond")
        return ([[cond, output]], )

NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
}

View File

@ -11,9 +11,10 @@ from typing import Any, Dict, Optional, List, Callable, Union
import torch
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
    PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
-    LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel
+    LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel, AutoModelForCausalLM
from typing_extensions import TypedDict

+from comfy import model_management
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel

@ -28,9 +29,9 @@ _AUTO_CHAT_TEMPLATE = "default"
try:
    from llava import model
-    logging.info("Additional LLaVA models are now supported")
+    logging.debug("Additional LLaVA models are now supported")
except ImportError as exc:
-    logging.info(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
+    logging.debug(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")

# aka kwargs type
_GENERATION_KWARGS_TYPE = Dict[str, Any]

@ -129,7 +130,7 @@ class TransformersGenerationConfig(CustomNode):
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
-                "model": ("MODEL",)
+                "model": ("MODEL", {})
            }
        }
@ -247,13 +248,22 @@ class TransformersLoader(CustomNode):
            **hub_kwargs
        }

-        try:
-            model = AutoModel.from_pretrained(**from_pretrained_kwargs)
-        except Exception as exc_info:
-            # not yet supported by automodel
-            model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
+        # try:
+        #     import flash_attn
+        #     from_pretrained_kwargs["attn_implementation"] = "flash_attention_2"
+        # except ImportError:
+        #     logging.debug("install flash_attn for improved performance using language nodes")

        config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
+        if config_dict["model_type"] == "llava_next":
+            model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
+        else:
+            try:
+                model = AutoModel.from_pretrained(**from_pretrained_kwargs)
+            except Exception:
+                model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs)

        try:
            try:
                processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs)

@ -265,6 +275,10 @@ class TransformersLoader(CustomNode):
            processor = None
        tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs)

+        if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
+            model.enable_xformers_memory_efficient_attention()
+            logging.debug("enabled xformers memory efficient attention")
+
        model_managed = TransformersManagedModel(
            repo_id=ckpt_name,
            model=model,
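
The loader now dispatches on the config's model_type instead of blindly trying AutoModel first, which keeps LLaVA-Next checkpoints on their dedicated class and gives decoder-only checkpoints a working fallback. A minimal standalone sketch of the same dispatch, assuming only a Hugging Face repo id (names are illustrative):

from transformers import (AutoModel, AutoModelForCausalLM,
                          LlavaNextForConditionalGeneration, PretrainedConfig)

def load_language_model(repo_id: str, **kwargs):
    # Read the config first; its model_type decides which class can actually load the weights.
    config_dict, _ = PretrainedConfig.get_config_dict(repo_id, trust_remote_code=True)
    if config_dict.get("model_type") == "llava_next":
        return LlavaNextForConditionalGeneration.from_pretrained(repo_id, **kwargs)
    try:
        return AutoModel.from_pretrained(repo_id, **kwargs)
    except Exception:
        # Some decoder-only models only register a causal-LM head.
        return AutoModelForCausalLM.from_pretrained(repo_id, **kwargs)

# model = load_language_model("microsoft/Phi-3-mini-4k-instruct", torch_dtype="auto")
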

View File

@ -263,6 +263,7 @@ class CLIPSave:
        metadata = {}
        if not args.disable_metadata:
+            metadata["format"] = "pt"
            metadata["prompt"] = prompt_info
            if extra_pnginfo is not None:
                for x in extra_pnginfo:

View File

@ -161,7 +161,7 @@ class BooleanRequestParameter(CustomNode):
            }
        }

-    RETURN_TYPES = ("STRING",)
+    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

View File

@ -108,3 +108,8 @@ NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeSD3": CLIPTextEncodeSD3,
    "ControlNetApplySD3": ControlNetApplySD3,
}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    # Sampling
+    "ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
+}

View File

@ -16,7 +16,7 @@ try:
    from spandrel import MAIN_REGISTRY
    MAIN_REGISTRY.add(*EXTRA_REGISTRY)
-    logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
+    logging.debug("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
except:
    pass
@ -26,20 +26,16 @@ class UpscaleModelManageable(ModelManageable):
        self.ckpt_name = ckpt_name
        self.model_descriptor = model_descriptor
        self.model = model_descriptor.model
-        self.load_device = model_management.unet_offload_device()
+        self.load_device = model_management.get_torch_device()
        self.offload_device = model_management.unet_offload_device()
-        self._current_device = self.offload_device
-        self._lowvram_patch_counter = 0
-
-        # Private properties for image sizes and channels
-        self._input_size = (1, 512, 512)  # Default input size (batch, height, width)
+        self._input_size = (1, 512, 512)
        self._input_channels = model_descriptor.input_channels
        self._output_channels = model_descriptor.output_channels
        self.tile = 512

    @property
    def current_device(self) -> torch.device:
-        return self._current_device
+        return self.model_descriptor.device

    @property
    def input_size(self) -> tuple[int, int, int]:

@ -65,21 +61,14 @@ class UpscaleModelManageable(ModelManageable):
    def is_clone(self, other: Any) -> bool:
        return isinstance(other, UpscaleModelManageable) and self.model is other.model

-    def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
+    def clone_has_same_weights(self, clone) -> bool:
        return self.is_clone(clone)

    def model_size(self) -> int:
-        # Calculate the size of the model parameters
        model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
-
-        # Get the byte size of the model's dtype
        dtype_size = torch.finfo(self.model_dtype()).bits // 8
-
-        # Calculate the memory required for input and output images
        input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
        output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size
-
-        # Add some extra memory for processing
        extra_memory = (input_size + output_size) * 2  # This is an estimate, adjust as needed

        return model_params_size + input_size + output_size + extra_memory

@ -95,30 +84,16 @@ class UpscaleModelManageable(ModelManageable):
    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
        self.model.to(device=device_to)
-        self._current_device = device_to
-        self._lowvram_patch_counter += 1
        return self.model

-    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
-        if patch_weights:
-            self.model.to(device=device_to)
-            self._current_device = device_to
+    def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
+        self.model.to(device=device_to)
        return self.model

-    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
-        if unpatch_weights:
-            self.model.to(device=offload_device)
-            self._current_device = offload_device
+    def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        self.model.to(device=offload_device)
        return self.model

-    @property
-    def lowvram_patch_counter(self) -> int:
-        return self._lowvram_patch_counter
-
-    @lowvram_patch_counter.setter
-    def lowvram_patch_counter(self, value: int):
-        self._lowvram_patch_counter = value
-
    def __str__(self):
        if self.ckpt_name is not None:
            return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"

View File

@ -46,7 +46,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
@pytest.mark.asyncio
-def test_known_repos(tmp_path_factory):
+async def test_known_repos(tmp_path_factory):
    prev_hub_cache = os.getenv("HF_HUB_CACHE")
    os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))

View File

@ -39,12 +39,17 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
    workflow = json.loads(workflow_file.read_text())
    prompt = Prompt.validate(workflow)
    # todo: add all the models we want to test a bit more elegantly
    outputs = await client.queue_prompt(prompt)

    if any(v.class_type == "SaveImage" for v in prompt.values()):
        save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
        assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None
    elif any(v.class_type == "SaveAudio" for v in prompt.values()):
-        save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
-        assert outputs[save_image_node_id]["audio"][0]["filename"] is not None
+        save_audio_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
+        assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None
+    elif any(v.class_type == "PreviewString" for v in prompt.values()):
+        save_image_node_id = next(key for key in prompt if prompt[key].class_type == "PreviewString")
+        output_str = outputs[save_image_node_id]["string"][0]
+        assert output_str is not None
+        assert len(output_str) > 0
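
The JSON fixtures added below exercise these branches: the language workflows end in a PreviewString node, while the upscale and ControlNet ones end in SaveImage. A minimal sketch of running one fixture outside the parametrized test, assuming the same async client helper and Prompt model used above (the path and function name are hypothetical):

import json
from pathlib import Path

async def run_fixture(client, Prompt, path: str = "tests/workflows/some_preview_string_workflow.json"):
    # Validate the workflow JSON, queue it, and check the PreviewString output.
    prompt = Prompt.validate(json.loads(Path(path).read_text()))
    outputs = await client.queue_prompt(prompt)
    node_id = next(k for k in prompt if prompt[k].class_type == "PreviewString")
    assert len(outputs[node_id]["string"][0]) > 0
    return outputs
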

View File

@ -0,0 +1,53 @@
{
"17": {
"inputs": {
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "ImageRequestParameter",
"_meta": {
"title": "ImageRequestParameter"
}
},
"19": {
"inputs": {
"model_name": "RealESRGAN_x4plus.pth"
},
"class_type": "UpscaleModelLoader",
"_meta": {
"title": "Load Upscale Model"
}
},
"20": {
"inputs": {
"upscale_model": [
"19",
0
],
"image": [
"17",
0
]
},
"class_type": "ImageUpscaleWithModel",
"_meta": {
"title": "Upscale Image (using Model)"
}
},
"21": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"20",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}

View File

@ -0,0 +1,91 @@
{
"1": {
"inputs": {
"ckpt_name": "llava-hf/llava-v1.6-mistral-7b-hf",
"subfolder": ""
},
"class_type": "TransformersLoader",
"_meta": {
"title": "TransformersLoader"
}
},
"3": {
"inputs": {
"max_new_tokens": 512,
"repetition_penalty": 0,
"seed": 2013744903,
"use_cache": true,
"__tokens": "\n\nThis is a black and white sketch of a woman. The image is stylized and does not provide enough detail to identify the specific person being depicted. It appears to be a portrait with a focus on the facial features and the hair, which is styled in a way that suggests it might be from a historical or classical period. The style of the drawing is reminiscent of the works of artists who specialize in portraiture, such as those from the Renaissance or the 19th century. </s>",
"model": [
"1",
0
],
"tokens": [
"4",
0
]
},
"class_type": "TransformersGenerate",
"_meta": {
"title": "TransformersGenerate"
}
},
"4": {
"inputs": {
"prompt": "Who is this?",
"chat_template": "llava-v1.6-mistral-7b-hf",
"model": [
"1",
0
],
"images": [
"8",
0
]
},
"class_type": "OneShotInstructTokenize",
"_meta": {
"title": "OneShotInstructTokenize"
}
},
"5": {
"inputs": {
"value": [
"3",
0
],
"output": "\n\nThis is a black and white sketch of a woman. The image is stylized and does not provide enough detail to identify the specific person being depicted. It appears to be a portrait with a focus on the facial features and the hair, which is styled in a way that suggests it might be from a historical or classical period. The style of the drawing is reminiscent of the works of artists who specialize in portraiture, such as those from the Renaissance or the 19th century. "
},
"class_type": "PreviewString",
"_meta": {
"title": "PreviewString"
}
},
"6": {
"inputs": {
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "ImageRequestParameter",
"_meta": {
"title": "ImageRequestParameter"
}
},
"8": {
"inputs": {
"upscale_method": "nearest-exact",
"megapixels": 1,
"image": [
"6",
0
]
},
"class_type": "ImageScaleToTotalPixels",
"_meta": {
"title": "ImageScaleToTotalPixels"
}
}
}

View File

@ -0,0 +1,60 @@
{
"1": {
"inputs": {
"ckpt_name": "microsoft/Phi-3-mini-4k-instruct",
"subfolder": ""
},
"class_type": "TransformersLoader",
"_meta": {
"title": "TransformersLoader"
}
},
"3": {
"inputs": {
"max_new_tokens": 512,
"repetition_penalty": 0,
"seed": 2514389986,
"use_cache": true,
"__tokens": "The question \"What comes after apple?\" can be interpreted in a few ways. If we're discussing the alphabetical sequence, the letter that comes after 'A' (for apple) is 'B'. If we're discussing a sequence of fruits, it could be any fruit that follows apple in a particular list or context. For example, in a list of fruits, banana might come after apple.<|end|>",
"model": [
"1",
0
],
"tokens": [
"4",
0
]
},
"class_type": "TransformersGenerate",
"_meta": {
"title": "TransformersGenerate"
}
},
"4": {
"inputs": {
"prompt": "What comes after apple?",
"chat_template": "phi-3",
"model": [
"1",
0
]
},
"class_type": "OneShotInstructTokenize",
"_meta": {
"title": "OneShotInstructTokenize"
}
},
"5": {
"inputs": {
"value": [
"3",
0
],
"output": "The question \"What comes after apple?\" can be interpreted in a few ways. If we're discussing the alphabetical sequence, the letter that comes after 'A' (for apple) is 'B'. If we're discussing a sequence of fruits, it could be any fruit that follows apple in a particular list or context. For example, in a list of fruits, banana might come after apple."
},
"class_type": "PreviewString",
"_meta": {
"title": "PreviewString"
}
}
}

View File

@ -0,0 +1,223 @@
{
"1": {
"inputs": {
"ckpt_name": "sd_xl_base_1.0.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"2": {
"inputs": {
"strength": 0.5,
"start_percent": 0,
"end_percent": 1,
"positive": [
"3",
0
],
"negative": [
"6",
0
],
"control_net": [
"7",
0
],
"image": [
"9",
0
]
},
"class_type": "ControlNetApplyAdvanced",
"_meta": {
"title": "Apply ControlNet (Advanced)"
}
},
"3": {
"inputs": {
"text": "a girl with blue hair",
"clip": [
"1",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"5": {
"inputs": {
"add_noise": true,
"noise_seed": 969970429360105,
"cfg": 8,
"model": [
"1",
0
],
"positive": [
"2",
0
],
"negative": [
"2",
1
],
"sampler": [
"13",
0
],
"sigmas": [
"11",
0
],
"latent_image": [
"12",
0
]
},
"class_type": "SamplerCustom",
"_meta": {
"title": "SamplerCustom"
}
},
"6": {
"inputs": {
"text": "",
"clip": [
"1",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"type": "canny/lineart/anime_lineart/mlsd",
"control_net": [
"8",
0
]
},
"class_type": "SetUnionControlNetType",
"_meta": {
"title": "SetUnionControlNetType"
}
},
"8": {
"inputs": {
"control_net_name": "xinsir-controlnet-union-sdxl-1.0-promax.safetensors"
},
"class_type": "ControlNetLoader",
"_meta": {
"title": "Load ControlNet Model"
}
},
"9": {
"inputs": {
"low_threshold": 0.4,
"high_threshold": 0.8,
"image": [
"18",
0
]
},
"class_type": "Canny",
"_meta": {
"title": "Canny"
}
},
"11": {
"inputs": {
"model_type": "SDXL",
"steps": 25,
"denoise": 1
},
"class_type": "AlignYourStepsScheduler",
"_meta": {
"title": "AlignYourStepsScheduler"
}
},
"12": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"13": {
"inputs": {
"eta": 1,
"s_noise": 1
},
"class_type": "SamplerEulerAncestral",
"_meta": {
"title": "SamplerEulerAncestral"
}
},
"14": {
"inputs": {
"samples": [
"5",
0
],
"vae": [
"1",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"15": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"14",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"17": {
"inputs": {
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "ImageRequestParameter",
"_meta": {
"title": "ImageRequestParameter"
}
},
"18": {
"inputs": {
"upscale_method": "lanczos",
"megapixels": 1,
"image": [
"17",
0
]
},
"class_type": "ImageScaleToTotalPixels",
"_meta": {
"title": "ImageScaleToTotalPixels"
}
}
}

View File

@ -0,0 +1,302 @@
{
"1": {
"inputs": {
"ckpt_name": "sd_xl_base_1.0.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"2": {
"inputs": {
"strength": 0.5,
"start_percent": 0,
"end_percent": 1,
"positive": [
"3",
0
],
"negative": [
"6",
0
],
"control_net": [
"28",
0
],
"image": [
"9",
0
]
},
"class_type": "ControlNetApplyAdvanced",
"_meta": {
"title": "Apply ControlNet (Advanced)"
}
},
"3": {
"inputs": {
"text": [
"26",
0
],
"clip": [
"1",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"5": {
"inputs": {
"add_noise": [
"23",
0
],
"noise_seed": [
"20",
0
],
"cfg": [
"19",
0
],
"model": [
"1",
0
],
"positive": [
"2",
0
],
"negative": [
"2",
1
],
"sampler": [
"24",
0
],
"sigmas": [
"11",
0
],
"latent_image": [
"12",
0
]
},
"class_type": "SamplerCustom",
"_meta": {
"title": "SamplerCustom"
}
},
"6": {
"inputs": {
"text": "",
"clip": [
"1",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"9": {
"inputs": {
"low_threshold": 0.4,
"high_threshold": 0.8,
"image": [
"18",
0
]
},
"class_type": "Canny",
"_meta": {
"title": "Canny"
}
},
"11": {
"inputs": {
"model_type": "SDXL",
"steps": 25,
"denoise": 1
},
"class_type": "AlignYourStepsScheduler",
"_meta": {
"title": "AlignYourStepsScheduler"
}
},
"12": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"14": {
"inputs": {
"samples": [
"5",
0
],
"vae": [
"1",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"15": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"14",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"17": {
"inputs": {
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "ImageRequestParameter",
"_meta": {
"title": "ImageRequestParameter"
}
},
"18": {
"inputs": {
"upscale_method": "lanczos",
"megapixels": 1,
"image": [
"17",
0
]
},
"class_type": "ImageScaleToTotalPixels",
"_meta": {
"title": "ImageScaleToTotalPixels"
}
},
"19": {
"inputs": {
"value": 8,
"name": "cfg",
"title": "",
"description": "",
"__required": true
},
"class_type": "FloatRequestParameter",
"_meta": {
"title": "FloatRequestParameter"
}
},
"20": {
"inputs": {
"value": 0,
"name": "seed",
"title": "",
"description": "",
"__required": true
},
"class_type": "IntRequestParameter",
"_meta": {
"title": "IntRequestParameter"
}
},
"23": {
"inputs": {
"value": true,
"name": "add_noise",
"title": "",
"description": "",
"__required": true
},
"class_type": "BooleanRequestParameter",
"_meta": {
"title": "BooleanRequestParameter"
}
},
"24": {
"inputs": {
"sampler_name": [
"25",
0
]
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"25": {
"inputs": {
"value": "euler",
"name": "sampler_name",
"title": "",
"description": "",
"__required": true
},
"class_type": "StringEnumRequestParameter",
"_meta": {
"title": "StringEnumRequestParameter"
}
},
"26": {
"inputs": {
"value": "a girl with blue hair",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "StringRequestParameter",
"_meta": {
"title": "StringRequestParameter"
}
},
"27": {
"inputs": {
"control_net_name": "xinsir-controlnet-union-sdxl-1.0-promax.safetensors"
},
"class_type": "ControlNetLoader",
"_meta": {
"title": "Load ControlNet Model"
}
},
"28": {
"inputs": {
"type": "canny/lineart/anime_lineart/mlsd",
"control_net": [
"27",
0
]
},
"class_type": "SetUnionControlNetType",
"_meta": {
"title": "SetUnionControlNetType"
}
}
}