Mirror of https://github.com/comfyanonymous/ComfyUI.git
Synced 2026-01-25 22:00:19 +08:00

Merge commit '39fb74c5bd13a1dccf4d7293a2f7a755d9f43cbd' of github.com:comfyanonymous/ComfyUI

- Improvements to tests
- Fixes model management
- Fixes issues with language nodes

This commit is contained in commit 0549f35e85.
.github/workflows/pullrequest-ci-run.yml (vendored, new file, 53 lines)
@@ -0,0 +1,53 @@
+# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
+# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
+name: Pull Request CI Workflow Runs
+on:
+  pull_request_target:
+    types: [labeled]
+
+jobs:
+  pr-test-stable:
+    if: ${{ github.event.label.name == 'Run-CI-Test' }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [macos, linux, windows]
+        python_version: ["3.9", "3.10", "3.11", "3.12"]
+        cuda_version: ["12.1"]
+        torch_version: ["stable"]
+        include:
+          - os: macos
+            runner_label: [self-hosted, macOS]
+            flags: "--use-pytorch-cross-attention"
+          - os: linux
+            runner_label: [self-hosted, Linux]
+            flags: ""
+          - os: windows
+            runner_label: [self-hosted, win]
+            flags: ""
+    runs-on: ${{ matrix.runner_label }}
+    steps:
+      - name: Test Workflows
+        uses: comfy-org/comfy-action@main
+        with:
+          os: ${{ matrix.os }}
+          python_version: ${{ matrix.python_version }}
+          torch_version: ${{ matrix.torch_version }}
+          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
+          comfyui_flags: ${{ matrix.flags }}
+          use_prior_commit: 'true'
+  comment:
+    if: ${{ github.event.label.name == 'Run-CI-Test' }}
+    runs-on: ubuntu-latest
+    permissions:
+      pull-requests: write
+    steps:
+      - uses: actions/github-script@v6
+        with:
+          script: |
+            github.rest.issues.createComment({
+              issue_number: context.issue.number,
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
+            })
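The workflow above runs only when the 'Run-CI-Test' label is added to a pull request (pull_request_target with types: [labeled]), and the comment job then links to the results on ci.comfy.org. As a hedged illustration of how that trigger is produced, the sketch below adds the label with PyGithub; the token, repository slug, and PR number are placeholders, not values taken from this commit:

import os
from github import Github  # PyGithub

gh = Github(os.environ["GITHUB_TOKEN"])        # assumption: a token with repo scope
repo = gh.get_repo("comfyanonymous/ComfyUI")   # assumption: target repository
issue = repo.get_issue(number=1234)            # assumption: the PR number; PRs are labeled via their issue
issue.add_to_labels("Run-CI-Test")             # the label name checked by the workflow's `if:` condition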
.github/workflows/test-ci.yml (vendored, new file, 95 lines)
@@ -0,0 +1,95 @@
+# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
+# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
+name: Full Comfy CI Workflow Runs
+on:
+  push:
+    branches:
+      - master
+    paths-ignore:
+      - 'app/**'
+      - 'input/**'
+      - 'output/**'
+      - 'notebooks/**'
+      - 'script_examples/**'
+      - '.github/**'
+      - 'web/**'
+  workflow_dispatch:
+
+jobs:
+  test-stable:
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [macos, linux, windows]
+        python_version: ["3.9", "3.10", "3.11", "3.12"]
+        cuda_version: ["12.1"]
+        torch_version: ["stable"]
+        include:
+          - os: macos
+            runner_label: [self-hosted, macOS]
+            flags: "--use-pytorch-cross-attention"
+          - os: linux
+            runner_label: [self-hosted, Linux]
+            flags: ""
+          - os: windows
+            runner_label: [self-hosted, win]
+            flags: ""
+    runs-on: ${{ matrix.runner_label }}
+    steps:
+      - name: Test Workflows
+        uses: comfy-org/comfy-action@main
+        with:
+          os: ${{ matrix.os }}
+          python_version: ${{ matrix.python_version }}
+          torch_version: ${{ matrix.torch_version }}
+          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
+          comfyui_flags: ${{ matrix.flags }}
+
+  test-win-nightly:
+    strategy:
+      fail-fast: true
+      matrix:
+        os: [windows]
+        python_version: ["3.9", "3.10", "3.11", "3.12"]
+        cuda_version: ["12.1"]
+        torch_version: ["nightly"]
+        include:
+          - os: windows
+            runner_label: [self-hosted, win]
+            flags: ""
+    runs-on: ${{ matrix.runner_label }}
+    steps:
+      - name: Test Workflows
+        uses: comfy-org/comfy-action@main
+        with:
+          os: ${{ matrix.os }}
+          python_version: ${{ matrix.python_version }}
+          torch_version: ${{ matrix.torch_version }}
+          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
+          comfyui_flags: ${{ matrix.flags }}
+
+  test-unix-nightly:
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [macos, linux]
+        python_version: ["3.11"]
+        cuda_version: ["12.1"]
+        torch_version: ["nightly"]
+        include:
+          - os: macos
+            runner_label: [self-hosted, macOS]
+            flags: "--use-pytorch-cross-attention"
+          - os: linux
+            runner_label: [self-hosted, Linux]
+            flags: ""
+    runs-on: ${{ matrix.runner_label }}
+    steps:
+      - name: Test Workflows
+        uses: comfy-org/comfy-action@main
+        with:
+          os: ${{ matrix.os }}
+          python_version: ${{ matrix.python_version }}
+          torch_version: ${{ matrix.torch_version }}
+          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
+          comfyui_flags: ${{ matrix.flags }}
.gitignore (vendored)
@@ -176,4 +176,5 @@ cython_debug/
 /tests-ui/data/object_info.json
 /user/
 *.log
 web_custom_versions/
+.DS_Store
@@ -14,9 +14,9 @@ from typing import TypedDict
 import requests
 from typing_extensions import NotRequired

-from comfy.cli_args import DEFAULT_VERSION_STRING
-from comfy.cmd.folder_paths import add_model_folder_path
-from comfy.component_model.files import get_package_as_path
+from ..cli_args import DEFAULT_VERSION_STRING
+from ..cmd.folder_paths import add_model_folder_path
+from ..component_model.files import get_package_as_path

 REQUEST_TIMEOUT = 10  # seconds

@@ -10,7 +10,6 @@ import time
 import traceback
 import typing
 from os import PathLike
-from pathlib import PurePath
 from typing import List, Optional, Tuple

 import lazy_object_proxy
@@ -34,9 +33,11 @@ from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOption
 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
 nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
-from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
-from comfy.graph_utils import is_link, GraphBuilder
-from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
+# order matters
+from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
+from ..graph_utils import is_link, GraphBuilder
+from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID


 class IsChangedCache:
@@ -446,6 +447,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
             "traceback": traceback.format_tb(tb),
             "current_inputs": input_data_formatted
         }

+        if isinstance(ex, model_management.OOM_EXCEPTION):
+            logging.error("Got an OOM, unloading all loaded models.")
+            model_management.unload_all_models()
+
         return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)

     executed.add(unique_id)
@@ -55,7 +55,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):

             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time
-            logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
+            logging.debug("Prompt executed in {:.2f} seconds".format(execution_time))

         flags = q.get_flags()
         free_memory = flags.get("free_memory", False)
@@ -838,7 +838,7 @@ class PromptServer(ExecutorToClientProgress):
         self.port = port

         if verbose:
-            logging.info("Starting server\n")
+            logging.info("Starting server")
             logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
         if call_on_start is not None:
             call_on_start(address, port)
@@ -1,4 +1,24 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+
 import torch
+from enum import Enum
 import math
 import os
 import logging
@@ -12,7 +32,8 @@ from . import latent_formats

 from .cldm import cldm, mmdit
 from .t2i_adapter import adapter
-from .ldm.cascade import controlnet
+from .ldm import hydit, flux
+from .ldm.cascade import controlnet as cascade_controlnet


 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -33,6 +54,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
     else:
         return torch.cat([tensor] * batched_number, dim=0)

+class StrengthType(Enum):
+    CONSTANT = 1
+    LINEAR_UP = 2
+
 class ControlBase:
     def __init__(self, device=None):
         self.cond_hint_original = None
@@ -51,6 +76,8 @@ class ControlBase:
             device = model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
+        self.extra_conds = []
+        self.strength_type = StrengthType.CONSTANT

     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
         self.cond_hint_original = cond_hint
@@ -93,6 +120,8 @@ class ControlBase:
         c.latent_format = self.latent_format
         c.extra_args = self.extra_args.copy()
         c.vae = self.vae
+        c.extra_conds = self.extra_conds.copy()
+        c.strength_type = self.strength_type

     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -113,7 +142,10 @@ class ControlBase:

                 if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
                     applied_to.add(x)
-                    x *= self.strength
+                    if self.strength_type == StrengthType.CONSTANT:
+                        x *= self.strength
+                    elif self.strength_type == StrengthType.LINEAR_UP:
+                        x *= (self.strength ** float(len(control_output) - i))

                 if x.dtype != output_dtype:
                     x = x.to(output_dtype)
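The hunk above makes control_merge scale each entry of control_output according to the new StrengthType enum instead of always multiplying by a constant. A minimal standalone sketch of that scaling rule, outside the repository's classes (the helper name is illustrative):

from enum import Enum

class StrengthType(Enum):
    CONSTANT = 1
    LINEAR_UP = 2

def scaled_strength(strength: float, strength_type: StrengthType, i: int, n_outputs: int) -> float:
    # CONSTANT: every control output is multiplied by the same strength.
    # LINEAR_UP: earlier outputs (small i) get the strength raised to a higher power,
    # mirroring x *= self.strength ** float(len(control_output) - i) above.
    if strength_type == StrengthType.CONSTANT:
        return strength
    elif strength_type == StrengthType.LINEAR_UP:
        return strength ** float(n_outputs - i)
    raise ValueError(strength_type)

# With strength=0.8 and n_outputs=3, LINEAR_UP yields 0.512, 0.64, 0.8 for i=0, 1, 2.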
@@ -142,7 +174,7 @@ class ControlBase:


 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, ckpt_name: str = None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, ckpt_name: str = None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
@@ -154,6 +186,8 @@ class ControlNet(ControlBase):
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype
         self.latent_format = latent_format
+        self.extra_conds += extra_conds
+        self.strength_type = strength_type

     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -191,13 +225,16 @@ class ControlNet(ControlBase):
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

         context = cond.get('crossattn_controlnet', cond['c_crossattn'])
-        y = cond.get('y', None)
-        if y is not None:
-            y = y.to(dtype)
+        extra = self.extra_args.copy()
+        for c in self.extra_conds:
+            temp = cond.get(c, None)
+            if temp is not None:
+                extra[c] = temp.to(dtype)
+
         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

-        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
         return self.control_merge(control, control_prev, output_dtype)

     def copy(self):
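The get_control() hunk above replaces the hard-coded handling of the 'y' tensor with a loop over self.extra_conds, so ControlNet variants can declare whichever conditioning tensors they need. A hedged sketch of that gathering step as a free function (illustrative names, not the repository's API):

import torch

def gather_extra_conds(cond: dict, extra_cond_names: list, dtype: torch.dtype, base_args: dict) -> dict:
    # analogous to: extra = self.extra_args.copy(); for c in self.extra_conds: ...
    extra = dict(base_args)
    for name in extra_cond_names:
        value = cond.get(name, None)
        if value is not None:
            extra[name] = value.to(dtype)
    return extra

# e.g. gather_extra_conds(cond, ['y', 'guidance'], torch.float16, {}) passes y= and
# guidance= to the control model only when they are present in the cond dict.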
@@ -286,6 +323,7 @@ class ControlLora(ControlNet):
         ControlBase.__init__(self, device)
         self.control_weights = control_weights
         self.global_average_pooling = global_average_pooling
+        self.extra_conds += ["y"]

     def pre_run(self, model, percent_to_timestep_function):
         super().pre_run(model, percent_to_timestep_function)
@@ -338,12 +376,8 @@ class ControlLora(ControlNet):
     def inference_memory_requirements(self, dtype):
         return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

-def load_controlnet_mmdit(sd):
-    new_sd = model_detection.convert_diffusers_mmdit(sd, "")
-    model_config = model_detection.model_config_from_unet(new_sd, "", True)
-    num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
-    for k in sd:
-        new_sd[k] = sd[k]
+def controlnet_config(sd):
+    model_config = model_detection.model_config_from_unet(sd, "", True)

     supported_inference_dtypes = model_config.supported_inference_dtypes

@@ -356,14 +390,27 @@ def load_controlnet_mmdit(sd):
     else:
         operations = ops.disable_weight_init

-    control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
-    missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
+    return model_config, operations, load_device, unet_dtype, manual_cast_dtype
+
+
+def controlnet_load_state_dict(control_model, sd):
+    missing, unexpected = control_model.load_state_dict(sd, strict=False)

     if len(missing) > 0:
         logging.warning("missing controlnet keys: {}".format(missing))

     if len(unexpected) > 0:
         logging.debug("unexpected controlnet keys: {}".format(unexpected))
+    return control_model
+
+
+def load_controlnet_mmdit(sd):
+    new_sd = model_detection.convert_diffusers_mmdit(sd, "")
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
+    num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)

     latent_format = latent_formats.SD3()
     latent_format.shift_factor = 0 #SD3 controlnet weirdness
@@ -371,8 +418,30 @@ def load_controlnet_mmdit(sd):
     return control


+def load_controlnet_hunyuandit(controlnet_data):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
+
+    control_model = hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
+    control_model = controlnet_load_state_dict(control_model, controlnet_data)
+
+    latent_format = latent_formats.SDXL()
+    extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
+    return control
+
+def load_controlnet_flux_xlabs(sd):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
+    control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, sd)
+    extra_conds = ['y', 'guidance']
+    control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
+
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
+    if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
+        return load_controlnet_hunyuandit(controlnet_data)
     if "lora_controlnet" in controlnet_data:
         return ControlLora(controlnet_data)

@@ -430,7 +499,10 @@ def load_controlnet(ckpt_path, model=None):
             logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
     elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
-        return load_controlnet_mmdit(controlnet_data)
+        if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
+            return load_controlnet_flux_xlabs(controlnet_data)
+        else:
+            return load_controlnet_mmdit(controlnet_data)

     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
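load_controlnet() above now dispatches on characteristic state-dict keys before any model is built: Hunyuan DiT, ControlLora, Flux XLabs, and SD3/MMDiT checkpoints are each recognized by a key that only they contain. A minimal sketch of that detection pattern (the family labels returned here are illustrative only; the key names are the ones checked above):

def detect_controlnet_family(state_dict: dict) -> str:
    if 'after_proj_list.18.bias' in state_dict:        # Hunyuan DiT
        return "hunyuandit"
    if "lora_controlnet" in state_dict:                 # ControlLora
        return "controllora"
    if "controlnet_blocks.0.weight" in state_dict:      # diffusers-style DiT controlnets
        if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
            return "flux_xlabs"
        return "sd3_mmdit"
    return "unet_style"                                 # fall through to the UNet-style loaders below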
@@ -590,11 +662,11 @@ def load_t2i_adapter(t2i_data):
             xl = True
         model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
     elif "backbone.0.0.weight" in keys:
-        model_ad = controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
+        model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
         compression_ratio = 32
         upscale_algorithm = 'bilinear'
     elif "backbone.10.blocks.0.weight" in keys:
-        model_ad = controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
+        model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
         compression_ratio = 1
         upscale_algorithm = 'nearest-exact'
     else:
@@ -28,7 +28,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire

     unet = None
     if unet_path is not None:
-        unet = sd.load_unet(unet_path)
+        unet = sd.load_diffusion_model(unet_path)

     clip = None
     textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"])
@@ -8,7 +8,7 @@ from tqdm.auto import trange, tqdm

 from . import utils
 from . import deis
-import comfy.model_patcher
+from .. import model_patcher

 def append_zero(x):
     return torch.cat([x, x.new_zeros([1])])
@@ -1032,7 +1032,7 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
         return args["denoised"]

     model_options = extra_args.get("model_options", {}).copy()
-    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+    extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
@@ -1058,7 +1058,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
         return args["denoised"]

     model_options = extra_args.get("model_options", {}).copy()
-    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+    extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
@@ -77,19 +77,9 @@ class TransformersManagedModel(ModelManageable):

         return self.model.config.to_dict()

-    @property
     def lowvram_patch_counter(self):
         return 0

-    @lowvram_patch_counter.setter
-    def lowvram_patch_counter(self, value: int):
-        warnings.warn("Not supported")
-        pass
-
-    load_device: torch.device
-    offload_device: torch.device
-    model: PreTrainedModel
-
     @property
     def current_device(self) -> torch.device:
         return self.model.device
@@ -127,12 +117,10 @@ class TransformersManagedModel(ModelManageable):
         warnings.warn("Transformers models do not currently support adapters like LoRAs")
         return self.model.to(device=device_to)

-    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
-        warnings.warn("Transformers models do not currently support adapters like LoRAs")
+    def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
         return self.model.to(device=device_to)

-    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
-        warnings.warn("Transformers models do not currently support adapters like LoRAs")
+    def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
         return self.model.to(device=offload_device)

     def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
@@ -177,7 +165,7 @@ class TransformersManagedModel(ModelManageable):
             self.processor.to(device=self.load_device)

         assert "<image>" in prompt.lower(), "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
-        batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
+        batch_feature: BatchFeature = self.processor([prompt], images=images.unbind(), padding=True, return_tensors="pt")
         if hasattr(self.processor, "to"):
             self.processor.to(device=self.offload_device)
         assert "input_ids" in batch_feature
comfy/ldm/flux/controlnet_xlabs.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
+
+import torch
+from torch import Tensor, nn
+from einops import rearrange, repeat
+
+from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
+                     MLPEmbedder, SingleStreamBlock,
+                     timestep_embedding)
+
+from .model import Flux
+import comfy.ldm.common_dit
+
+
+class ControlNetFlux(Flux):
+    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
+
+        # add ControlNet blocks
+        self.controlnet_blocks = nn.ModuleList([])
+        for _ in range(self.params.depth):
+            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+            # controlnet_block = zero_module(controlnet_block)
+            self.controlnet_blocks.append(controlnet_block)
+        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        self.gradient_checkpointing = False
+        self.input_hint_block = nn.Sequential(
+            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
+        )
+
+    def forward_orig(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        controlnet_cond: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        timesteps: Tensor,
+        y: Tensor,
+        guidance: Tensor = None,
+    ) -> Tensor:
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        # running on sequences img
+        img = self.img_in(img)
+        controlnet_cond = self.input_hint_block(controlnet_cond)
+        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        controlnet_cond = self.pos_embed_input(controlnet_cond)
+        img = img + controlnet_cond
+        vec = self.time_in(timestep_embedding(timesteps, 256))
+        if self.params.guidance_embed:
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+        vec = vec + self.vector_in(y)
+        txt = self.txt_in(txt)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+
+        block_res_samples = ()
+
+        for block in self.double_blocks:
+            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+            block_res_samples = block_res_samples + (img,)
+
+        controlnet_block_res_samples = ()
+        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+            block_res_sample = controlnet_block(block_res_sample)
+            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+        return {"output": (controlnet_block_res_samples * 10)[:19]}
+
+    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
+        hint = hint * 2.0 - 1.0
+
+        bs, c, h, w = x.shape
+        patch_size = 2
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+
+        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+
+        h_len = ((h + (patch_size // 2)) // patch_size)
+        w_len = ((w + (patch_size // 2)) // patch_size)
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
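forward_orig() above returns (controlnet_block_res_samples * 10)[:19]: the XLabs controlnet can have fewer double blocks than the 19 in the base Flux model, so its per-block residuals are tiled and then truncated so that every base block receives one. The tuple arithmetic, shown standalone with placeholder values:

block_res_samples = ("res0", "res1")        # stand-ins for the per-block tensors
outputs = (block_res_samples * 10)[:19]     # tile to 20 entries, keep the first 19
assert len(outputs) == 19
assert outputs[:3] == ("res0", "res1", "res0")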
@@ -2,12 +2,12 @@ import math
 from dataclasses import dataclass

 import torch
-from einops import rearrange
 from torch import Tensor, nn

 from .math import attention, rope
 from ... import ops


 class EmbedND(nn.Module):
     def __init__(self, dim: int, theta: int, axes_dim: list):
         super().__init__()
@@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     """
     t = time_factor * t
     half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        t.device
-    )
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)

     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -48,7 +46,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     embedding = embedding.to(t)
     return embedding

-
 class MLPEmbedder(nn.Module):
     def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
         super().__init__()
@@ -94,14 +91,6 @@ class SelfAttention(nn.Module):
         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
         self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)

-    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
-        qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
-        q, k = self.norm(q, k, v)
-        x = attention(q, k, v, pe=pe)
-        x = self.proj(x)
-        return x
-

 @dataclass
 class ModulationOut:
@@ -163,22 +152,21 @@ class DoubleStreamBlock(nn.Module):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         # run actual attention
-        q = torch.cat((txt_q, img_q), dim=2)
-        k = torch.cat((txt_k, img_k), dim=2)
-        v = torch.cat((txt_v, img_v), dim=2)
-
-        attn = attention(q, k, v, pe=pe)
+        attn = attention(torch.cat((txt_q, img_q), dim=2),
+                         torch.cat((txt_k, img_k), dim=2),
+                         torch.cat((txt_v, img_v), dim=2), pe=pe)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

         # calculate the img bloks
@@ -186,8 +174,12 @@ class DoubleStreamBlock(nn.Module):
         img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

         # calculate the txt bloks
-        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+
+        if txt.dtype == torch.float16:
+            txt = txt.clip(-65504, 65504)

         return img, txt

@@ -232,14 +224,17 @@ class SingleStreamBlock(nn.Module):
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)

         # compute attention
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        return x + mod.gate * output
+        x += mod.gate * output
+        if x.dtype == torch.float16:
+            x = x.clip(-65504, 65504)
+        return x


 class LastLayer(nn.Module):
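The layers.py hunks above drop the einops dependency from the attention path by replacing rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=num_heads) with an equivalent view/permute. A small self-check of that equivalence (the shapes are arbitrary test sizes, not values from the model):

import torch
from einops import rearrange

B, L, H, D = 2, 5, 4, 8
qkv = torch.randn(B, L, 3 * H * D)

via_einops = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=H)
via_view = qkv.view(qkv.shape[0], qkv.shape[1], 3, H, -1).permute(2, 0, 3, 1, 4)

assert torch.equal(via_einops, via_view)  # both are (3, B, H, L, D)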
@@ -38,7 +38,7 @@ class Flux(nn.Module):
     Transformer model for flow matching on sequences.
     """

-    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
         self.dtype = dtype
         params = FluxParams(**kwargs)
@@ -83,7 +83,8 @@ class Flux(nn.Module):
             ]
         )

-        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+        if final_layer:
+            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)

     def forward_orig(
         self,
@@ -94,6 +95,7 @@ class Flux(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
+        control=None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -112,8 +114,15 @@ class Flux(nn.Module):
         ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)

-        for block in self.double_blocks:
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+        for i in range(len(self.double_blocks)):
+            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
+
+            if control is not None: #Controlnet
+                control_o = control.get("output")
+                if i < len(control_o):
+                    add = control_o[i]
+                    if add is not None:
+                        img += add

         img = torch.cat((txt, img), 1)
         for block in self.single_blocks:
|
|||||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y, guidance, **kwargs):
|
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
patch_size = 2
|
patch_size = 2
|
||||||
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
@ -138,5 +147,5 @@ class Flux(nn.Module):
|
|||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
|
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
|
||||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x


 def rotate_half(x):
-    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
     return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


@@ -78,10 +78,9 @@ def apply_rotary_emb(
     xk_out = None
     if isinstance(freqs_cis, tuple):
         cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
-        cos, sin = cos.to(xq.device), sin.to(xq.device)
-        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
+        xq_out = (xq * cos + rotate_half(xq) * sin)
         if xk is not None:
-            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
+            xk_out = (xk * cos + rotate_half(xk) * sin)
     else:
         xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
         freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
comfy/ldm/hydit/controlnet.py (new file, 321 lines)
@@ -0,0 +1,321 @@
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.utils import checkpoint
+
+from comfy.ldm.modules.diffusionmodules.mmdit import (
+    Mlp,
+    TimestepEmbedder,
+    PatchEmbed,
+    RMSNorm,
+)
+from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
+from .poolers import AttentionPool
+
+import comfy.latent_formats
+from .models import HunYuanDiTBlock, calc_rope
+
+from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
+
+
+class HunYuanControlNet(nn.Module):
+    """
+    HunYuanDiT: Diffusion model with a Transformer backbone.
+
+    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
+
+    Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
+
+    Parameters
+    ----------
+    args: argparse.Namespace
+        The arguments parsed by argparse.
+    input_size: tuple
+        The size of the input image.
+    patch_size: int
+        The size of the patch.
+    in_channels: int
+        The number of input channels.
+    hidden_size: int
+        The hidden size of the transformer backbone.
+    depth: int
+        The number of transformer blocks.
+    num_heads: int
+        The number of attention heads.
+    mlp_ratio: float
+        The ratio of the hidden size of the MLP in the transformer block.
+    log_fn: callable
+        The logging function.
+    """
+
+    def __init__(
+        self,
+        input_size: tuple = 128,
+        patch_size: int = 2,
+        in_channels: int = 4,
+        hidden_size: int = 1408,
+        depth: int = 40,
+        num_heads: int = 16,
+        mlp_ratio: float = 4.3637,
+        text_states_dim=1024,
+        text_states_dim_t5=2048,
+        text_len=77,
+        text_len_t5=256,
+        qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+        size_cond=False,
+        use_style_cond=False,
+        learn_sigma=True,
+        norm="layer",
+        log_fn: callable = print,
+        attn_precision=None,
+        dtype=None,
+        device=None,
+        operations=None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.log_fn = log_fn
+        self.depth = depth
+        self.learn_sigma = learn_sigma
+        self.in_channels = in_channels
+        self.out_channels = in_channels * 2 if learn_sigma else in_channels
+        self.patch_size = patch_size
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.text_states_dim = text_states_dim
+        self.text_states_dim_t5 = text_states_dim_t5
+        self.text_len = text_len
+        self.text_len_t5 = text_len_t5
+        self.size_cond = size_cond
+        self.use_style_cond = use_style_cond
+        self.norm = norm
+        self.dtype = dtype
+        self.latent_format = comfy.latent_formats.SDXL
+
+        self.mlp_t5 = nn.Sequential(
+            nn.Linear(
+                self.text_states_dim_t5,
+                self.text_states_dim_t5 * 4,
+                bias=True,
+                dtype=dtype,
+                device=device,
+            ),
+            nn.SiLU(),
+            nn.Linear(
+                self.text_states_dim_t5 * 4,
+                self.text_states_dim,
+                bias=True,
+                dtype=dtype,
+                device=device,
+            ),
+        )
+        # learnable replace
+        self.text_embedding_padding = nn.Parameter(
+            torch.randn(
+                self.text_len + self.text_len_t5,
+                self.text_states_dim,
+                dtype=dtype,
+                device=device,
+            )
+        )
+
+        # Attention pooling
+        pooler_out_dim = 1024
+        self.pooler = AttentionPool(
+            self.text_len_t5,
+            self.text_states_dim_t5,
+            num_heads=8,
+            output_dim=pooler_out_dim,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+
+        # Dimension of the extra input vectors
+        self.extra_in_dim = pooler_out_dim
+
+        if self.size_cond:
+            # Image size and crop size conditions
+            self.extra_in_dim += 6 * 256
+
+        if self.use_style_cond:
+            # Here we use a default learned embedder layer for future extension.
+            self.style_embedder = nn.Embedding(
+                1, hidden_size, dtype=dtype, device=device
+            )
+            self.extra_in_dim += hidden_size
+
+        # Text embedding for `add`
+        self.x_embedder = PatchEmbed(
+            input_size,
+            patch_size,
+            in_channels,
+            hidden_size,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+        self.t_embedder = TimestepEmbedder(
+            hidden_size, dtype=dtype, device=device, operations=operations
+        )
+        self.extra_embedder = nn.Sequential(
+            operations.Linear(
+                self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
+            ),
+            nn.SiLU(),
+            operations.Linear(
+                hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
+            ),
+        )
+
+        # Image embedding
+        num_patches = self.x_embedder.num_patches
+
+        # HUnYuanDiT Blocks
+        self.blocks = nn.ModuleList(
+            [
+                HunYuanDiTBlock(
+                    hidden_size=hidden_size,
+                    c_emb_size=hidden_size,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    text_states_dim=self.text_states_dim,
+                    qk_norm=qk_norm,
+                    norm_type=self.norm,
+                    skip=False,
+                    attn_precision=attn_precision,
+                    dtype=dtype,
+                    device=device,
+                    operations=operations,
+                )
+                for _ in range(19)
+            ]
+        )
+
+        # Input zero linear for the first block
+        self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+
+        # Output zero linear for the every block
+        self.after_proj_list = nn.ModuleList(
+            [
+                operations.Linear(
+                    self.hidden_size, self.hidden_size, dtype=dtype, device=device
+                )
+                for _ in range(len(self.blocks))
+            ]
+        )
+
+    def forward(
+        self,
+        x,
+        hint,
+        timesteps,
+        context,  # encoder_hidden_states=None,
+        text_embedding_mask=None,
+        encoder_hidden_states_t5=None,
+        text_embedding_mask_t5=None,
+        image_meta_size=None,
+        style=None,
+        return_dict=False,
+        **kwarg,
+    ):
+        """
+        Forward pass of the encoder.
+
+        Parameters
+        ----------
+        x: torch.Tensor
+            (B, D, H, W)
+        t: torch.Tensor
+            (B)
+        encoder_hidden_states: torch.Tensor
+            CLIP text embedding, (B, L_clip, D)
+        text_embedding_mask: torch.Tensor
+            CLIP text embedding mask, (B, L_clip)
+        encoder_hidden_states_t5: torch.Tensor
+            T5 text embedding, (B, L_t5, D)
+        text_embedding_mask_t5: torch.Tensor
+            T5 text embedding mask, (B, L_t5)
+        image_meta_size: torch.Tensor
+            (B, 6)
+        style: torch.Tensor
+            (B)
+        cos_cis_img: torch.Tensor
+        sin_cis_img: torch.Tensor
+        return_dict: bool
+            Whether to return a dictionary.
+        """
+        condition = hint
+        if condition.shape[0] == 1:
+            condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
+
+        text_states = context  # 2,77,1024
+        text_states_t5 = encoder_hidden_states_t5  # 2,256,2048
+        text_states_mask = text_embedding_mask.bool()  # 2,77
+        text_states_t5_mask = text_embedding_mask_t5.bool()  # 2,256
+        b_t5, l_t5, c_t5 = text_states_t5.shape
+        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
+
+        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
+
+        text_states[:, -self.text_len :] = torch.where(
+            text_states_mask[:, -self.text_len :].unsqueeze(2),
+            text_states[:, -self.text_len :],
+            padding[: self.text_len],
+        )
+        text_states_t5[:, -self.text_len_t5 :] = torch.where(
+            text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
+            text_states_t5[:, -self.text_len_t5 :],
+            padding[self.text_len :],
+        )
+
+        text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,205,1024
+
+        # _, _, oh, ow = x.shape
+        # th, tw = oh // self.patch_size, ow // self.patch_size
+
+        # Get image RoPE embedding according to `reso`lution.
+        freqs_cis_img = calc_rope(
+            x, self.patch_size, self.hidden_size // self.num_heads
+        )  # (cos_cis_img, sin_cis_img)
+
+        # ========================= Build time and image embedding =========================
+        t = self.t_embedder(timesteps, dtype=self.dtype)
+        x = self.x_embedder(x)
+
+        # ========================= Concatenate all extra vectors =========================
+        # Build text tokens with pooling
+        extra_vec = self.pooler(encoder_hidden_states_t5)
+
+        # Build image meta size tokens if applicable
+        # if image_meta_size is not None:
+        #     image_meta_size = timestep_embedding(image_meta_size.view(-1), 256)   # [B * 6, 256]
+        #     if image_meta_size.dtype != self.dtype:
+        #         image_meta_size = image_meta_size.half()
+        #     image_meta_size = image_meta_size.view(-1, 6 * 256)
+        #     extra_vec = torch.cat([extra_vec, image_meta_size], dim=1)  # [B, D + 6 * 256]
+
+        # Build style tokens
+        if style is not None:
+            style_embedding = self.style_embedder(style)
+            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
+
+        # Concatenate all extra vectors
+        c = t + self.extra_embedder(extra_vec)  # [B, D]
+
+        # ========================= Deal with Condition =========================
+        condition = self.x_embedder(condition)
+
+        # ========================= Forward pass through HunYuanDiT blocks =========================
|
||||||
|
controls = []
|
||||||
|
x = x + self.before_proj(condition) # add condition
|
||||||
|
for layer, block in enumerate(self.blocks):
|
||||||
|
x = block(x, c, text_states, freqs_cis_img)
|
||||||
|
controls.append(self.after_proj_list[layer](x)) # zero linear for output
|
||||||
|
|
||||||
|
return {"output": controls}
|
||||||
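# --- Illustrative sketch (not part of the commit) ---------------------------
# The new ControlNet above returns {"output": controls}: one zero-initialized
# linear projection of the hidden state per transformer block. A later hunk in
# this commit reads that list back in HunYuanDiT.forward via
# control.get("output", None). The toy module below only demonstrates the
# per-block residual-injection pattern; the class name and tensor sizes are
# hypothetical and unrelated to the real model.
import torch
import torch.nn as nn

class TinyControlledStack(nn.Module):
    def __init__(self, hidden=8, depth=3):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(depth))

    def forward(self, x, controls=None):
        for i, block in enumerate(self.blocks):
            x = block(x)
            if controls is not None:
                x = x + controls[i]  # add the matching per-block control residual
        return x

controls = [torch.zeros(1, 8) for _ in range(3)]  # stands in for the "output" list above
print(TinyControlledStack()(torch.randn(1, 8), controls).shape)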
@@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size):
     sub_args = [start, stop, (th, tw)]
     # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
     rope = get_2d_rotary_pos_embed(head_size, *sub_args)
+    rope = (rope[0].to(x), rope[1].to(x))
     return rope

@@ -91,6 +92,8 @@ class HunYuanDiTBlock(nn.Module):
         # Long Skip Connection
         if self.skip_linear is not None:
             cat = torch.cat([x, skip], dim=-1)
+            if cat.dtype != x.dtype:
+                cat = cat.to(x.dtype)
             cat = self.skip_norm(cat)
             x = self.skip_linear(cat)

@@ -362,6 +365,8 @@ class HunYuanDiT(nn.Module):
         c = t + self.extra_embedder(extra_vec)  # [B, D]

         controls = None
+        if control:
+            controls = control.get("output", None)
         # ========================= Forward pass through HunYuanDiT blocks =========================
         skips = []
         for layer, block in enumerate(self.blocks):
@@ -411,17 +411,17 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 optimized_attention = attention_basic

 if model_management.xformers_enabled():
-    logging.info("Using xformers cross attention")
+    logging.debug("Using xformers cross attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
-    logging.info("Using pytorch cross attention")
+    logging.debug("Using pytorch cross attention")
     optimized_attention = attention_pytorch
 else:
     if args.use_split_cross_attention:
-        logging.info("Using split optimization for cross attention")
+        logging.debug("Using split optimization for cross attention")
         optimized_attention = attention_split
     else:
-        logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        logging.debug("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad

 optimized_attention_masked = optimized_attention
@@ -268,13 +268,13 @@ class AttnBlock(nn.Module):
                                          padding=0)

         if model_management.xformers_enabled_vae():
-            logging.info("Using xformers attention in VAE")
+            logging.debug("Using xformers attention in VAE")
             self.optimized_attention = xformers_attention
         elif model_management.pytorch_attention_enabled():
-            logging.info("Using pytorch attention in VAE")
+            logging.debug("Using pytorch attention in VAE")
             self.optimized_attention = pytorch_attention
         else:
-            logging.info("Using split attention in VAE")
+            logging.debug("Using split attention in VAE")
             self.optimized_attention = normal_attention

     def forward(self, x):
@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import logging
 from . import utils
 from . import model_base
@@ -216,11 +234,17 @@ def model_lora_keys_clip(model, key_map={}):
             lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
             key_map[lora_key] = k

-    for k in sdk: #OneTrainer SD3 lora
-        if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
-            l_key = k[len("t5xxl.transformer."):-len(".weight")]
-            lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
-            key_map[lora_key] = k
+    for k in sdk:
+        if k.endswith(".weight"):
+            if k.startswith("t5xxl.transformer."): #OneTrainer SD3 lora
+                l_key = k[len("t5xxl.transformer."):-len(".weight")]
+                lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
+                key_map[lora_key] = k
+            elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
+                l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
+                lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
+                key_map[lora_key] = k


     k = "clip_g.transformer.text_projection.weight"
     if k in sdk:
@@ -243,6 +267,7 @@ def model_lora_keys_unet(model, key_map={}):
             key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = k
             key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
+            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names

     diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
     for k in diffusers_keys:
@@ -1,7 +1,26 @@
-import logging
-from enum import Enum
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import logging
 import math
+from enum import Enum
+from typing import TypeVar, Type

 import torch

 from . import conds
@@ -11,15 +30,16 @@ from . import ops
 from . import utils
 from .ldm.audio.dit import AudioDiffusionTransformer
 from .ldm.audio.embedders import NumberConditioner
+from .ldm.aura.mmdit import MMDiT as AuraMMDiT
 from .ldm.cascade.stage_b import StageB
 from .ldm.cascade.stage_c import StageC
+from .ldm.flux import model as flux_model
+from .ldm.hydit.models import HunYuanDiT
 from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
 from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
-from .ldm.aura.mmdit import MMDiT as AuraMMDiT
-from .ldm.hydit.models import HunYuanDiT
-from .ldm.flux import model as flux_model


 class ModelType(Enum):
     EPS = 1
@@ -68,26 +88,33 @@ def model_sampling(model_config, model_type):
     return ModelSampling(model_config)


+TModule = TypeVar('TModule', bound=torch.nn.Module)
+
+
 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device: torch.device = None, unet_model: Type[TModule] = UNetModel):
         super().__init__()

         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
         self.manual_cast_dtype = model_config.manual_cast_dtype
+        self.device: torch.device = device

         if not unet_config.get("disable_unet_model_creation", False):
-            if self.manual_cast_dtype is not None:
-                operations = ops.manual_cast
+            if model_config.custom_operations is None:
+                if self.manual_cast_dtype is not None:
+                    operations = ops.manual_cast
+                else:
+                    operations = ops.disable_weight_init
             else:
-                operations = ops.disable_weight_init
+                operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
             if model_management.force_channels_last():
                 # todo: ???
                 self.diffusion_model.to(memory_format=torch.channels_last)
                 logging.debug("using channels last mode for diffusion model")
-        logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
+        logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)

@@ -96,7 +123,7 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = 0

         self.concat_keys = ()
-        logging.info("model_type {}".format(model_type.name))
+        logging.debug("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor

@@ -669,6 +696,7 @@ class StableAudio1(BaseModel):
                 sd["{}{}".format(k, l)] = s[l]
         return sd

+
 class HunyuanDiT(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=HunYuanDiT)
@@ -701,6 +729,7 @@ class HunyuanDiT(BaseModel):
         out['image_meta_size'] = conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
         return out

+
 class Flux(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=flux_model.Flux)
@@ -136,8 +136,8 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["hidden_size"] = 3072
         dit_config["mlp_ratio"] = 4.0
         dit_config["num_heads"] = 24
-        dit_config["depth"] = 19
-        dit_config["depth_single_blocks"] = 38
+        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
+        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
         dit_config["axes_dim"] = [16, 56, 56]
         dit_config["theta"] = 10000
         dit_config["qkv_bias"] = True
@@ -494,7 +494,12 @@ def model_config_from_diffusers_unet(state_dict):
 def convert_diffusers_mmdit(state_dict, output_prefix=""):
     out_sd = {}

-    if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+    if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
+        depth = count_blocks(state_dict, 'transformer_blocks.{}.')
+        depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+        hidden_size = state_dict["x_embedder.bias"].shape[0]
+        sd_map = utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
+    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
         num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         sd_map = utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
@@ -520,7 +525,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
             old_weight = out_sd.get(t[0], None)
             if old_weight is None:
                 old_weight = torch.empty_like(weight)
-                old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
+            if old_weight.shape[offset[0]] < offset[1] + offset[2]:
+                exp = list(weight.shape)
+                exp[offset[0]] = offset[1] + offset[2]
+                new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
+                new[:old_weight.shape[0]] = old_weight
+                old_weight = new

             w = old_weight.narrow(offset[0], offset[1], offset[2])
         else:
@@ -571,7 +571,7 @@ def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_i
             # fix path representation
             local_files = set(f.replace("\\", "/") for f in local_files)
             # remove .huggingface
-            local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface"))
+            local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
             # local_files.issubsetof(repo_files)
             if len(local_files) > 0 and local_files.issubset(repo_files):
                 local_dirs_snapshots.append(str(local_path))
@@ -1,3 +1,20 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
 from __future__ import annotations

 import logging
@@ -143,10 +160,10 @@ if torch.cuda.is_available() and hasattr(torch.version, "hip") and torch.version
     logging.info(f"Detected HIP device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
-logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+logging.debug("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))

 try:
-    logging.info("pytorch version: {}".format(torch.version.__version__))
+    logging.debug("pytorch version: {}".format(torch.version.__version__))
 except:
     pass

|
|||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
XFORMERS_VERSION = xformers.version.__version__
|
XFORMERS_VERSION = xformers.version.__version__
|
||||||
logging.info("xformers version: {}".format(XFORMERS_VERSION))
|
logging.debug("xformers version: {}".format(XFORMERS_VERSION))
|
||||||
if XFORMERS_VERSION.startswith("0.0.18"):
|
if XFORMERS_VERSION.startswith("0.0.18"):
|
||||||
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
||||||
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
|
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
|
||||||
@@ -263,12 +280,12 @@ if cpu_state != CPUState.GPU:
     if cpu_state == CPUState.MPS:
         vram_state = VRAMState.SHARED

-logging.info(f"Set vram state to: {vram_state.name}")
+logging.debug(f"Set vram state to: {vram_state.name}")

 DISABLE_SMART_MEMORY = args.disable_smart_memory

 if DISABLE_SMART_MEMORY:
-    logging.info("Disabling smart memory management")
+    logging.debug("Disabling smart memory management")


 def get_torch_device_name(device):
@@ -288,7 +305,7 @@ def get_torch_device_name(device):


 try:
-    logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
+    logging.debug("Device: {}".format(get_torch_device_name(get_torch_device())))
 except:
     logging.warning("Could not pick default device.")

@@ -315,9 +332,12 @@ class LoadedModel:
     def model_memory(self):
         return self.model.model_size()

+    def model_offloaded_memory(self):
+        return self.model.model_size() - self.model.loaded_size()
+
     def model_memory_required(self, device):
-        if device == self.model.current_device:
-            return 0
+        if device == self.model.current_loaded_device():
+            return self.model_offloaded_memory()
         else:
             return self.model_memory()

@@ -329,15 +349,21 @@ class LoadedModel:

         load_weights = not self.weights_loaded

-        try:
-            if lowvram_model_memory > 0 and load_weights:
-                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
-            else:
-                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
-        except Exception as e:
-            self.model.unpatch_model(self.model.offload_device)
-            self.model_unload()
-            raise e
+        if self.model.loaded_size() > 0:
+            use_more_vram = lowvram_model_memory
+            if use_more_vram == 0:
+                use_more_vram = 1e32
+            self.model_use_more_vram(use_more_vram)
+        else:
+            try:
+                if lowvram_model_memory > 0 and load_weights:
+                    self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
+                else:
+                    self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
+            except Exception as e:
+                self.model.unpatch_model(self.model.offload_device)
+                self.model_unload()
+                raise e

         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
@@ -346,15 +372,24 @@ class LoadedModel:
         return self.real_model

     def should_reload_model(self, force_patch_weights=False):
-        if force_patch_weights and self.model.lowvram_patch_counter > 0:
+        if force_patch_weights and self.model.lowvram_patch_counter() > 0:
             return True
         return False

-    def model_unload(self, unpatch_weights=True):
+    def model_unload(self, memory_to_free=None, unpatch_weights=True):
+        if memory_to_free is not None:
+            if memory_to_free < self.model.loaded_size():
+                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                if freed >= memory_to_free:
+                    return False
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
         self.real_model = None
+        return True
+
+    def model_use_more_vram(self, extra_memory):
+        return self.model.partially_load(self.device, extra_memory)

     def __eq__(self, other):
         return self.model is other.model
@@ -366,39 +401,59 @@ class LoadedModel:
         return f"<LoadedModel>"


+def use_more_memory(extra_memory, loaded_models, device):
+    for m in loaded_models:
+        if m.device == device:
+            extra_memory -= m.model_use_more_vram(extra_memory)
+            if extra_memory <= 0:
+                break
+
+
+def offloaded_memory(loaded_models, device):
+    offloaded_mem = 0
+    for m in loaded_models:
+        if m.device == device:
+            offloaded_mem += m.model_offloaded_memory()
+    return offloaded_mem
+
+
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 1.2


 def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
     with model_management_lock:
-        to_unload = []
-        for i in range(len(current_loaded_models)):
-            if model.is_clone(current_loaded_models[i].model):
-                to_unload = [i] + to_unload
-
-        if len(to_unload) == 0:
-            return True
-
-        same_weights = 0
-        for i in to_unload:
-            if model.clone_has_same_weights(current_loaded_models[i].model):
-                same_weights += 1
-
-        if same_weights == len(to_unload):
-            unload_weight = False
-        else:
-            unload_weight = True
-
-        if not force_unload:
-            if unload_weights_only and unload_weight == False:
-                return None
-
-        for i in to_unload:
-            logging.debug("unload clone {}{}".format(i, unload_weight))
-            current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
-
-        return unload_weight
+        return _unload_model_clones(model, unload_weights_only, force_unload)
+
+
+def _unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
+    to_unload = []
+    for i in range(len(current_loaded_models)):
+        if model.is_clone(current_loaded_models[i].model):
+            to_unload = [i] + to_unload
+
+    if len(to_unload) == 0:
+        return True
+
+    same_weights = 0
+    for i in to_unload:
+        if model.clone_has_same_weights(current_loaded_models[i].model):
+            same_weights += 1
+
+    if same_weights == len(to_unload):
+        unload_weight = False
+    else:
+        unload_weight = True
+
+    if not force_unload:
+        if unload_weights_only and unload_weight == False:
+            return None
+
+    for i in to_unload:
+        logging.debug("unload clone {} {}".format(i, unload_weight))
+        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
+
+    return unload_weight


 @tracer.start_as_current_span("Free Memory")
@@ -406,126 +461,158 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
     span = get_current_span()
     span.set_attribute("memory_required", memory_required)
     with model_management_lock:
-        unloaded_models: List[LoadedModel] = []
-        can_unload = []
-
-        for i in range(len(current_loaded_models) - 1, -1, -1):
-            shift_model = current_loaded_models[i]
-            if shift_model.device == device:
-                if shift_model not in keep_loaded:
-                    can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
-                    shift_model.currently_used = False
-
-        for x in sorted(can_unload):
-            i = x[-1]
-            if not DISABLE_SMART_MEMORY:
-                if get_free_memory(device) > memory_required:
-                    break
-            current_loaded_models[i].model_unload()
-            unloaded_models.append(i)
-
-        for i in sorted(unloaded_models, reverse=True):
-            current_loaded_models.pop(i)
-
-        if len(unloaded_models) > 0:
-            soft_empty_cache()
-        else:
-            if vram_state != VRAMState.HIGH_VRAM:
-                mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
-                if mem_free_torch > mem_free_total * 0.25:
-                    soft_empty_cache()
-
+        unloaded_models = _free_memory(memory_required, device, keep_loaded)
         span.set_attribute("unloaded_models", list(map(str, unloaded_models)))
         return unloaded_models


+def _free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
+    unloaded_model = []
+    can_unload = []
+    unloaded_models = []
+
+    for i in range(len(current_loaded_models) - 1, -1, -1):
+        shift_model = current_loaded_models[i]
+        if shift_model.device == device:
+            if shift_model not in keep_loaded:
+                can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+                shift_model.currently_used = False
+
+    for x in sorted(can_unload):
+        i = x[-1]
+        memory_to_free = None
+        if not DISABLE_SMART_MEMORY:
+            free_mem = get_free_memory(device)
+            if free_mem > memory_required:
+                break
+            memory_to_free = memory_required - free_mem
+        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+        if current_loaded_models[i].model_unload(memory_to_free):
+            unloaded_model.append(i)
+
+    for i in sorted(unloaded_model, reverse=True):
+        unloaded_models.append(current_loaded_models.pop(i))
+
+    if len(unloaded_model) > 0:
+        soft_empty_cache()
+    else:
+        if vram_state != VRAMState.HIGH_VRAM:
+            mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
+            if mem_free_torch > mem_free_total * 0.25:
+                soft_empty_cache()
+    return unloaded_models
+
+
 @tracer.start_as_current_span("Load Models GPU")
-def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None) -> None:
-    global vram_state
+def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
     span = get_current_span()
     if memory_required != 0:
         span.set_attribute("memory_required", memory_required)
     with model_management_lock:
-        inference_memory = minimum_inference_memory()
-        extra_mem = max(inference_memory, memory_required)
-        if minimum_memory_required is None:
-            minimum_memory_required = extra_mem
-        else:
-            minimum_memory_required = max(inference_memory, minimum_memory_required)
-
-        models = set(models)
-        models_to_load = []
-        models_already_loaded = []
-        for x in models:
-            loaded_model = LoadedModel(x)
-            loaded = None
-
-            try:
-                loaded_model_index = current_loaded_models.index(loaded_model)
-            except ValueError:
-                loaded_model_index = None
-
-            if loaded_model_index is not None:
-                loaded = current_loaded_models[loaded_model_index]
-                if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
-                    current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
-                    loaded = None
-                else:
-                    loaded.currently_used = True
-                    models_already_loaded.append(loaded)
-            if loaded is None:
-                models_to_load.append(loaded_model)
-
-        models_freed: List[LoadedModel] = []
-        try:
-            if len(models_to_load) == 0:
-                devs = set(map(lambda a: a.device, models_already_loaded))
-                for d in devs:
-                    if d != torch.device("cpu"):
-                        models_freed += free_memory(extra_mem, d, models_already_loaded)
-                return
-
-            total_memory_required = {}
-            for loaded_model in models_to_load:
-                if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
-                    total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
-
-            for device in total_memory_required:
-                if device != torch.device("cpu"):
-                    # todo: where does 1.3 come from?
-                    models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
-
-            for loaded_model in models_to_load:
-                weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
-                if weights_unloaded is not None:
-                    loaded_model.weights_loaded = not weights_unloaded
-
-            for loaded_model in models_to_load:
-                model = loaded_model.model
-                torch_dev = model.load_device
-                if is_device_cpu(torch_dev):
-                    vram_set_state = VRAMState.DISABLED
-                else:
-                    vram_set_state = vram_state
-                lowvram_model_memory = 0
-                if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-                    model_size = loaded_model.model_memory_required(torch_dev)
-                    current_free_mem = get_free_memory(torch_dev)
-                    lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
-                    if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
-                        lowvram_model_memory = 0
-
-                if vram_set_state == VRAMState.NO_VRAM:
-                    lowvram_model_memory = 64 * 1024 * 1024
-
-                loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
-                current_loaded_models.insert(0, loaded_model)
-            return
-        finally:
-            span.set_attribute("models", list(map(str, models)))
-            span.set_attribute("models_to_load", list(map(str, models_to_load)))
-            span.set_attribute("models_freed", list(map(str, models_freed)))
-            logging.info(f"Requested to load {','.join(map(str, models))}, models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
+        _load_models_gpu(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+        to_load = list(map(str, models))
+        span.set_attribute("models", to_load)
+        logging.info(f"Loaded {to_load}")
+
+
+def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
+    global vram_state
+
+    inference_memory = minimum_inference_memory()
+    extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
+    if minimum_memory_required is None:
+        minimum_memory_required = extra_mem
+    else:
+        minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
+
+    models = set(models)
+
+    models_to_load = []
+    models_already_loaded = []
+    models_freed = []
+    for x in models:
+        loaded_model = LoadedModel(x)
+        loaded = None
+
+        try:
+            loaded_model_index = current_loaded_models.index(loaded_model)
+        except:
+            loaded_model_index = None
+
+        if loaded_model_index is not None:
+            loaded = current_loaded_models[loaded_model_index]
+            if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
+                current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
+                loaded = None
+            else:
+                loaded.currently_used = True
+                models_already_loaded.append(loaded)
+
+        if loaded is None:
+            models_to_load.append(loaded_model)
+
+    if len(models_to_load) == 0:
+        devs = set(map(lambda a: a.device, models_already_loaded))
+        for d in devs:
+            if d != torch.device("cpu"):
+                free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
+                free_mem = get_free_memory(d)
+                if free_mem < minimum_memory_required:
+                    models_to_load = free_memory(minimum_memory_required, d)
+                    models_freed += models_to_load
+                else:
+                    use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
+        if len(models_to_load) == 0:
+            return
+
+    total_memory_required = {}
+    for loaded_model in models_to_load:
+        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) # unload clones where the weights are different
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+
+    for loaded_model in models_already_loaded:
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+
+    for loaded_model in models_to_load:
+        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
+        if weights_unloaded is not None:
+            loaded_model.weights_loaded = not weights_unloaded
+
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            models_freed += free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
+
+    for loaded_model in models_to_load:
+        model = loaded_model.model
+        torch_dev = model.load_device
+        if is_device_cpu(torch_dev):
+            vram_set_state = VRAMState.DISABLED
+        else:
+            vram_set_state = vram_state
+        lowvram_model_memory = 0
+        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
+            model_size = loaded_model.model_memory_required(torch_dev)
+            current_free_mem = get_free_memory(torch_dev)
+            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
+            if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
+                lowvram_model_memory = 0
+
+        if vram_set_state == VRAMState.NO_VRAM:
+            lowvram_model_memory = 64 * 1024 * 1024
+
+        cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
+        current_loaded_models.insert(0, loaded_model)
+
+    devs = set(map(lambda a: a.device, models_already_loaded))
+    for d in devs:
+        if d != torch.device("cpu"):
+            free_mem = get_free_memory(d)
+            if free_mem > minimum_memory_required:
+                use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
+
+    span = get_current_span()
+    span.set_attribute("models_to_load", list(map(str, models_to_load)))
+    span.set_attribute("models_freed", list(map(str, models_freed)))


 @_deprecate_method(message="Use load_models_gpu instead", version="0.0.2")
@@ -605,6 +692,7 @@ def unet_initial_load_device(parameters, dtype):
 def maximum_vram_for_weights(device=None):
     return (get_total_memory(device) * 0.88 - minimum_inference_memory())

+
 def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
     if args.bf16_unet:
         return torch.bfloat16
@@ -629,12 +717,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
         if model_params * 2 > free_model_memory:
             return fp8_dtype

-    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
-        if torch.float16 in supported_dtypes:
-            return torch.float16
-    if should_use_bf16(device, model_params=model_params, manual_cast=True):
-        if torch.bfloat16 in supported_dtypes:
-            return torch.bfloat16
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16
+
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16

     return torch.float32

@@ -651,13 +749,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.flo
     if bf16_supported and weight_dtype == torch.bfloat16:
         return None

-    if fp16_supported and torch.float16 in supported_dtypes:
-        return torch.float16
-
-    elif bf16_supported and torch.bfloat16 in supported_dtypes:
-        return torch.bfloat16
-
-    else:
-        return torch.float32
+    for dt in supported_dtypes:
+        if dt == torch.float16 and fp16_supported:
+            return torch.float16
+        if dt == torch.bfloat16 and bf16_supported:
+            return torch.bfloat16
+
+    return torch.float32


 def text_encoder_offload_device():
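# --- Illustrative sketch (not part of the commit) ---------------------------
# The two hunks above change unet_dtype()/unet_manual_cast() from a fixed
# fp16-then-bf16 preference to walking `supported_dtypes` in the caller's
# order. The standalone helper below only shows that priority-loop idea; the
# `supports` predicate is a hypothetical stand-in for should_use_fp16 /
# should_use_bf16 and is not part of ComfyUI's API.
import torch

def pick_dtype(supported_dtypes, supports):
    # Return the first caller-preferred dtype the device can use, else fp32.
    for dt in supported_dtypes:
        if supports(dt):
            return dt
    return torch.float32

# bf16 is listed first, so it wins when the device supports it:
print(pick_dtype((torch.bfloat16, torch.float16), lambda dt: dt in (torch.bfloat16, torch.float16)))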
@@ -679,6 +777,21 @@ def text_encoder_device():
     return torch.device("cpu")


+def text_encoder_initial_device(load_device, offload_device, model_size=0):
+    if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
+        return offload_device
+
+    if is_device_mps(load_device):
+        return offload_device
+
+    mem_l = get_free_memory(load_device)
+    mem_o = get_free_memory(offload_device)
+    if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
+        return load_device
+    else:
+        return offload_device
+
+
 def text_encoder_dtype(device=None):
     if args.fp8_e4m3fn_text_enc:
         return torch.float8_e4m3fn
@@ -1,8 +1,23 @@
 from __future__ import annotations

-from typing import Protocol, Optional, Any
+import dataclasses
+from typing import Protocol, Optional, TypeVar, runtime_checkable

 import torch
+import torch.nn
+
+ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
+
+
+@runtime_checkable
+class DeviceSettable(Protocol):
+    @property
+    def device(self) -> torch.device:
+        ...
+
+    @device.setter
+    def device(self, value: torch.device):
+        ...
+
+
 class ModelManageable(Protocol):
@@ -22,32 +37,92 @@ class ModelManageable(Protocol):

     @property
     def current_device(self) -> torch.device:
-        ...
+        return next(self.model.parameters()).device

-    def is_clone(self, other: Any) -> bool:
-        ...
+    def is_clone(self, other: ModelManageableT) -> bool:
+        return other.model is self.model

-    def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
-        ...
+    def clone_has_same_weights(self, clone: ModelManageableT) -> bool:
+        return clone.model is self.model

     def model_size(self) -> int:
-        ...
+        from .model_management import module_size
+        return module_size(self.model)

     def model_patches_to(self, arg: torch.device | torch.dtype):
-        ...
+        pass

     def model_dtype(self) -> torch.dtype:
-        ...
+        return next(self.model.parameters()).dtype

     def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
-        ...
-
-    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
+        self.patch_model(device_to=device_to, patch_weights=False)
+        return self.model
+
+    def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
+        """
+        Loads the model to the device
+        :param device_to: the device to move the model weights to
+        :param patch_weights: True if the patch's weights should also be moved
+        :return:
+        """
         ...

-    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
-        ...
+    def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        """
+        Unloads the model by moving it to the offload device
+        :param offload_device:
+        :param unpatch_weights:
+        :return:
+        """
+        ...
+
+    def lowvram_patch_counter(self) -> int:
+        return 0
+
+    def partially_load(self, device_to: torch.device, extra_memory=0) -> int:
+        self.patch_model(device_to=device_to)
+        return self.model_size()
+
+    def partially_unload(self, device_to: torch.device, extra_memory=0) -> int:
+        self.unpatch_model(device_to)
+        return self.model_size()
+
+    def memory_required(self, input_shape) -> int:
+        from comfy.model_base import BaseModel
+
+        if isinstance(self.model, BaseModel):
+            return self.model.memory_required(input_shape=input_shape)
+        else:
+            # todo: why isn't this true?
+            return self.model_size()
+
+    def loaded_size(self) -> int:
+        if self.current_loaded_device() == self.load_device:
+            return self.model_size()
+        return 0
+
+    def current_loaded_device(self) -> torch.device:
+        return self.current_device
+
+
+@dataclasses.dataclass
+class MemoryMeasurements:
+    model: torch.nn.Module | DeviceSettable
+    model_loaded_weight_memory: int = 0
+    lowvram_patch_counter: int = 0
+    model_lowvram: bool = False
+    _device: torch.device | None = None

     @property
-    def lowvram_patch_counter(self) -> int:
-        ...
+    def device(self) -> torch.device:
+        if isinstance(self.model, DeviceSettable):
+            return self.model.device
+        else:
+            return self._device
+
+    @device.setter
+    def device(self, value: torch.device):
+        if isinstance(self.model, DeviceSettable):
+            self.model.device = value
+        self._device = value
@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+import collections
 import copy
 import inspect
 import logging
@ -5,10 +23,12 @@ import uuid
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn
|
||||||
|
|
||||||
from . import model_management
|
from . import model_management
|
||||||
from . import utils
|
from . import utils
|
||||||
from .model_management_types import ModelManageable
|
from .model_base import BaseModel
|
||||||
|
from .model_management_types import ModelManageable, MemoryMeasurements
|
||||||
from .types import UnetWrapperFunction
|
from .types import UnetWrapperFunction
|
||||||
|
|
||||||
|
|
||||||
@ -69,10 +89,27 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
|||||||
return model_options
|
return model_options
|
||||||
|
|
||||||
|
|
||||||
|
def wipe_lowvram_weight(m):
|
||||||
|
if hasattr(m, "prev_comfy_cast_weights"):
|
||||||
|
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||||
|
del m.prev_comfy_cast_weights
|
||||||
|
m.weight_function = None
|
||||||
|
m.bias_function = None
|
||||||
|
|
||||||
|
|
||||||
|
class LowVramPatch:
|
||||||
|
def __init__(self, key, model_patcher):
|
||||||
|
self.key = key
|
||||||
|
self.model_patcher = model_patcher
|
||||||
|
|
||||||
|
def __call__(self, weight):
|
||||||
|
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||||
|
|
||||||
|
|
||||||
class ModelPatcher(ModelManageable):
|
class ModelPatcher(ModelManageable):
|
||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False, ckpt_name: Optional[str] = None):
|
def __init__(self, model: torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.model = model
|
self.model: torch.nn.Module = model
|
||||||
self.patches = {}
|
self.patches = {}
|
||||||
self.backup = {}
|
self.backup = {}
|
||||||
self.object_patches = {}
|
self.object_patches = {}
|
||||||
@ -81,25 +118,21 @@ class ModelPatcher(ModelManageable):
|
|||||||
self.model_size()
|
self.model_size()
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
self.offload_device = offload_device
|
self.offload_device = offload_device
|
||||||
self._current_device: torch.device
|
|
||||||
if current_device is None:
|
|
||||||
self._current_device = self.offload_device
|
|
||||||
else:
|
|
||||||
self._current_device = current_device
|
|
||||||
|
|
||||||
self.weight_inplace_update = weight_inplace_update
|
self.weight_inplace_update = weight_inplace_update
|
||||||
self.model_lowvram = False
|
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
self.ckpt_name = ckpt_name
|
self.ckpt_name = ckpt_name
|
||||||
self._lowvram_patch_counter = 0
|
self._memory_measurements = MemoryMeasurements(self.model)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lowvram_patch_counter(self):
|
def model_device(self) -> torch.device:
|
||||||
return self._lowvram_patch_counter
|
return self._memory_measurements.device
|
||||||
|
|
||||||
@lowvram_patch_counter.setter
|
@model_device.setter
|
||||||
def lowvram_patch_counter(self, value: int):
|
def model_device(self, value: torch.device):
|
||||||
self._lowvram_patch_counter = value
|
self._memory_measurements.device = value
|
||||||
|
|
||||||
|
def lowvram_patch_counter(self):
|
||||||
|
return self._memory_measurements.lowvram_patch_counter
|
||||||
|
|
||||||
def model_size(self):
|
def model_size(self):
|
||||||
if self.size > 0:
|
if self.size > 0:
|
||||||
@ -107,8 +140,12 @@ class ModelPatcher(ModelManageable):
|
|||||||
self.size = model_management.module_size(self.model)
|
self.size = model_management.module_size(self.model)
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
|
def loaded_size(self):
|
||||||
|
return self._memory_measurements.model_loaded_weight_memory
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
|
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
||||||
|
n._memory_measurements = self._memory_measurements
|
||||||
n.ckpt_name = self.ckpt_name
|
n.ckpt_name = self.ckpt_name
|
||||||
n.patches = {}
|
n.patches = {}
|
||||||
for k in self.patches:
|
for k in self.patches:
|
||||||
@ -122,11 +159,9 @@ class ModelPatcher(ModelManageable):
|
|||||||
return n
|
return n
|
||||||
|
|
||||||
def is_clone(self, other):
|
def is_clone(self, other):
|
||||||
if hasattr(other, 'model') and self.model is other.model:
|
return hasattr(other, 'model') and self.model is other.model
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def clone_has_same_weights(self, clone):
|
def clone_has_same_weights(self, clone: "ModelPatcher"):
|
||||||
if not self.is_clone(clone):
|
if not self.is_clone(clone):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -139,7 +174,8 @@ class ModelPatcher(ModelManageable):
|
|||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def memory_required(self, input_shape):
|
def memory_required(self, input_shape) -> int:
|
||||||
|
assert isinstance(self.model, BaseModel)
|
||||||
return self.model.memory_required(input_shape=input_shape)
|
return self.model.memory_required(input_shape=input_shape)
|
||||||
|
|
||||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||||
@ -281,16 +317,16 @@ class ModelPatcher(ModelManageable):
|
|||||||
sd.pop(k)
|
sd.pop(k)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def patch_weight_to_device(self, key, device_to=None):
|
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||||
if key not in self.patches:
|
if key not in self.patches:
|
||||||
return
|
return
|
||||||
|
|
||||||
weight = utils.get_attr(self.model, key)
|
weight = utils.get_attr(self.model, key)
|
||||||
|
|
||||||
inplace_update = self.weight_inplace_update
|
inplace_update = self.weight_inplace_update or inplace_update
|
||||||
|
|
||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
@ -319,31 +355,25 @@ class ModelPatcher(ModelManageable):
|
|||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
self._current_device = device_to
|
self.model_device = device_to
|
||||||
|
self._memory_measurements.model_loaded_weight_memory = self.model_size()
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
self.patch_model(device_to, patch_weights=False)
|
|
||||||
|
|
||||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
|
|
||||||
|
|
||||||
class LowVramPatch:
|
|
||||||
def __init__(self, key, model_patcher):
|
|
||||||
self.key = key
|
|
||||||
self.model_patcher = model_patcher
|
|
||||||
|
|
||||||
def __call__(self, weight):
|
|
||||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
|
||||||
|
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
|
lowvram_counter = 0
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
lowvram_weight = False
|
lowvram_weight = False
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
|
||||||
|
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||||
module_mem = model_management.module_size(m)
|
module_mem = model_management.module_size(m)
|
||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if mem_counter + module_mem >= lowvram_model_memory:
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
|
lowvram_counter += 1
|
||||||
|
if m.comfy_cast_weights:
|
||||||
|
continue
|
||||||
|
|
||||||
weight_key = "{}.weight".format(n)
|
weight_key = "{}.weight".format(n)
|
||||||
bias_key = "{}.bias".format(n)
|
bias_key = "{}.bias".format(n)
|
||||||
@ -365,15 +395,39 @@ class ModelPatcher(ModelManageable):
|
|||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
else:
|
else:
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
if m.comfy_cast_weights:
|
||||||
|
wipe_lowvram_weight(m)
|
||||||
|
|
||||||
if hasattr(m, "weight"):
|
if hasattr(m, "weight"):
|
||||||
self.patch_weight_to_device(weight_key, device_to)
|
|
||||||
self.patch_weight_to_device(bias_key, device_to)
|
|
||||||
m.to(device_to)
|
|
||||||
mem_counter += model_management.module_size(m)
|
mem_counter += model_management.module_size(m)
|
||||||
|
param = list(m.parameters())
|
||||||
|
if len(param) > 0:
|
||||||
|
weight = param[0]
|
||||||
|
if weight.device == device_to:
|
||||||
|
continue
|
||||||
|
|
||||||
|
weight_to = None
|
||||||
|
if full_load: # TODO
|
||||||
|
weight_to = device_to
|
||||||
|
self.patch_weight_to_device(weight_key, device_to=weight_to) # TODO: speed this up without OOM
|
||||||
|
self.patch_weight_to_device(bias_key, device_to=weight_to)
|
||||||
|
m.to(device_to)
|
||||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
|
|
||||||
self.model_lowvram = True
|
if lowvram_counter > 0:
|
||||||
self.lowvram_patch_counter = patch_counter
|
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||||
|
self._memory_measurements.model_lowvram = True
|
||||||
|
else:
|
||||||
|
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
||||||
|
self._memory_measurements.model_lowvram = False
|
||||||
|
self._memory_measurements.lowvram_patch_counter += patch_counter
|
||||||
|
self._memory_measurements.model_loaded_weight_memory = mem_counter
|
||||||
|
self.model_device = device_to
|
||||||
|
|
||||||
|
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||||
|
self.patch_model(device_to, patch_weights=False)
|
||||||
|
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def calculate_weight(self, patches, weight, key):
|
def calculate_weight(self, patches, weight, key):
|
||||||
@ -548,31 +602,28 @@ class ModelPatcher(ModelManageable):
|
|||||||
|
|
||||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||||
if unpatch_weights:
|
if unpatch_weights:
|
||||||
if self.model_lowvram:
|
if self._memory_measurements.model_lowvram:
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
if hasattr(m, "prev_comfy_cast_weights"):
|
wipe_lowvram_weight(m)
|
||||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
|
||||||
del m.prev_comfy_cast_weights
|
|
||||||
m.weight_function = None
|
|
||||||
m.bias_function = None
|
|
||||||
|
|
||||||
self.model_lowvram = False
|
self._memory_measurements.model_lowvram = False
|
||||||
self.lowvram_patch_counter = 0
|
self._memory_measurements.lowvram_patch_counter = 0
|
||||||
|
|
||||||
keys = list(self.backup.keys())
|
keys = list(self.backup.keys())
|
||||||
|
|
||||||
if self.weight_inplace_update:
|
for k in keys:
|
||||||
for k in keys:
|
bk = self.backup[k]
|
||||||
utils.copy_to_param(self.model, k, self.backup[k])
|
if bk.inplace_update:
|
||||||
else:
|
utils.copy_to_param(self.model, k, bk.weight)
|
||||||
for k in keys:
|
else:
|
||||||
utils.set_attr_param(self.model, k, self.backup[k])
|
utils.set_attr_param(self.model, k, bk.weight)
|
||||||
|
|
||||||
self.backup.clear()
|
self.backup.clear()
|
||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
self._current_device = value = device_to
|
self.model_device = device_to
|
||||||
|
self._memory_measurements.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
keys = list(self.object_patches_backup.keys())
|
keys = list(self.object_patches_backup.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
@ -580,9 +631,66 @@ class ModelPatcher(ModelManageable):
|
|||||||
|
|
||||||
self.object_patches_backup.clear()
|
self.object_patches_backup.clear()
|
||||||
|
|
||||||
|
def partially_unload(self, device_to, memory_to_free=0):
|
||||||
|
memory_freed = 0
|
||||||
|
patch_counter = 0
|
||||||
|
|
||||||
|
for n, m in list(self.model.named_modules())[::-1]:
|
||||||
|
if memory_to_free < memory_freed:
|
||||||
|
break
|
||||||
|
|
||||||
|
shift_lowvram = False
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
module_mem = model_management.module_size(m)
|
||||||
|
weight_key = "{}.weight".format(n)
|
||||||
|
bias_key = "{}.bias".format(n)
|
||||||
|
|
||||||
|
if m.weight is not None and m.weight.device != device_to:
|
||||||
|
for key in [weight_key, bias_key]:
|
||||||
|
bk = self.backup.get(key, None)
|
||||||
|
if bk is not None:
|
||||||
|
if bk.inplace_update:
|
||||||
|
utils.copy_to_param(self.model, key, bk.weight)
|
||||||
|
else:
|
||||||
|
utils.set_attr_param(self.model, key, bk.weight)
|
||||||
|
self.backup.pop(key)
|
||||||
|
|
||||||
|
m.to(device_to)
|
||||||
|
if weight_key in self.patches:
|
||||||
|
m.weight_function = LowVramPatch(weight_key, self)
|
||||||
|
patch_counter += 1
|
||||||
|
if bias_key in self.patches:
|
||||||
|
m.bias_function = LowVramPatch(bias_key, self)
|
||||||
|
patch_counter += 1
|
||||||
|
|
||||||
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
|
m.comfy_cast_weights = True
|
||||||
|
memory_freed += module_mem
|
||||||
|
logging.debug("freed {}".format(n))
|
||||||
|
|
||||||
|
self._memory_measurements.model_lowvram = True
|
||||||
|
self._memory_measurements.lowvram_patch_counter += patch_counter
|
||||||
|
self._memory_measurements.model_loaded_weight_memory -= memory_freed
|
||||||
|
return memory_freed
|
||||||
|
|
||||||
|
def partially_load(self, device_to, extra_memory=0):
|
||||||
|
self.unpatch_model(unpatch_weights=False)
|
||||||
|
self.patch_model(patch_weights=False)
|
||||||
|
full_load = False
|
||||||
|
if not self._memory_measurements.model_lowvram:
|
||||||
|
return 0
|
||||||
|
if self._memory_measurements.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||||
|
full_load = True
|
||||||
|
current_used = self._memory_measurements.model_loaded_weight_memory
|
||||||
|
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
||||||
|
return self._memory_measurements.model_loaded_weight_memory - current_used
|
||||||
|
|
||||||
|
def current_loaded_device(self):
|
||||||
|
return self.model_device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_device(self) -> torch.device:
|
def current_device(self) -> torch.device:
|
||||||
return self._current_device
|
return self.current_loaded_device()
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.ckpt_name is not None:
|
if self.ckpt_name is not None:
|
||||||
|
|||||||
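As a rough usage sketch of the partial load/unload flow added to ModelPatcher above (devices and byte counts are illustrative; `model` stands for any module the patcher already manages):

    import torch

    patcher = ModelPatcher(model, load_device=torch.device("cuda"), offload_device=torch.device("cpu"))
    patcher.patch_model_lowvram(device_to=torch.device("cuda"), lowvram_model_memory=2 * 1024 ** 3)

    # Free roughly 512 MB of weights back to the CPU, then grow the loaded portion again.
    freed = patcher.partially_unload(torch.device("cpu"), memory_to_free=512 * 1024 ** 2)
    regained = patcher.partially_load(torch.device("cuda"), extra_memory=256 * 1024 ** 2)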
@@ -823,14 +823,14 @@ class UNETLoader:
     CATEGORY = "advanced/loaders"

     def load_unet(self, unet_name, weight_dtype):
-        dtype = None
+        model_options = {}
         if weight_dtype == "fp8_e4m3fn":
-            dtype = torch.float8_e4m3fn
+            model_options["dtype"] = torch.float8_e4m3fn
         elif weight_dtype == "fp8_e5m2":
-            dtype = torch.float8_e5m2
+            model_options["dtype"] = torch.float8_e5m2

         unet_path = get_or_download("unet", unet_name, KNOWN_UNET_MODELS)
-        model = sd.load_unet(unet_path, dtype=dtype)
+        model = sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)

 class CLIPLoader:
@@ -91,7 +91,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,

     if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings):
         for (duration, module_name, success, new_nodes) in sorted(timings):
-            logging.info(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
+            logging.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
     if raise_on_failure and len(exceptions) > 0:
         try:
             raise ExceptionGroup("Node import failed", exceptions)
83  comfy/sd.py
@@ -20,21 +20,19 @@ from . import model_sampling
 from . import sd1_clip
 from . import sdxl_clip
 from . import utils
-from .model_management import load_models_gpu
-
-from .text_encoders import sd2_clip
-from .text_encoders import sd3_clip
-from .text_encoders import hydit
-from .text_encoders import sa_t5
-from .text_encoders import aura_t5
-from .text_encoders import flux
-
 from .ldm.audio.autoencoder import AudioOobleckVAE
 from .ldm.cascade.stage_a import StageA
 from .ldm.cascade.stage_c_coder import StageC_coder
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
+from .model_management import load_models_gpu
 from .t2i_adapter import adapter
 from .taesd import taesd
+from .text_encoders import aura_t5
+from .text_encoders import flux
+from .text_encoders import hydit
+from .text_encoders import sa_t5
+from .text_encoders import sd2_clip
+from .text_encoders import sd3_clip


 def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
@@ -68,7 +66,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):

 class CLIP:
-    def __init__(self, target: CLIPTarget=None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None=None):
+    def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0):
         if tokenizer_data is None:
             tokenizer_data = dict()
         if no_init:
@@ -79,9 +77,9 @@ class CLIP:

         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        params['device'] = offload_device
         dtype = model_management.text_encoder_dtype(load_device)
         params['dtype'] = dtype
+        params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
         if "textmodel_json_config" not in params and textmodel_json_config is not None:
             params['textmodel_json_config'] = textmodel_json_config
@@ -90,11 +88,16 @@ class CLIP:
         for dt in self.cond_stage_model.dtypes:
             if not model_management.supports_cast(load_device, dt):
                 load_device = offload_device
+                if params['device'] != offload_device:
+                    self.cond_stage_model.to(offload_device)
+                    logging.warning("Had to shift TE back.")

         self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        if params['device'] == load_device:
+            model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
-        logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
+        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))

     def clone(self):
         n = CLIP(no_init=True)
@@ -476,7 +479,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
             clip_target.clip = sd3_clip.SD3ClipModel
             clip_target.tokenizer = sd3_clip.SD3Tokenizer

-    clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config)
+    parameters = 0
+    for c in clip_data:
+        parameters += utils.calculate_parameters(c)
+
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters)
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
@@ -523,15 +530,21 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     return (model, clip, vae)


-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
     sd = utils.load_torch_file(ckpt_path)
-    sd_keys = sd.keys()
+    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, ckpt_path=ckpt_path)
+    if out is None:
+        raise RuntimeError("Could not detect model type of: {}".format(ckpt_path))
+    return out
+
+
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, ckpt_path: str | None = None):
     clip = None
     clipvision = None
     vae = None
     model = None
     _model_patcher = None
-    clip_target = None
+    inital_load_device = None

     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     parameters = utils.calculate_parameters(sd, diffusion_model_prefix)
@@ -540,13 +553,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o

     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
     if model_config is None:
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+        return None

     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if weight_dtype is not None:
         unet_weight_dtype.append(weight_dtype)

-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+    model_config.custom_operations = model_options.get("custom_operations", None)
+    unet_dtype = model_options.get("weight_dtype", None)
+
+    if unet_dtype is None:
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@@ -570,7 +588,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if clip_target is not None:
         clip_sd = model_config.process_clip_state_dict(sd)
         if len(clip_sd) > 0:
-            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
+            parameters = utils.calculate_parameters(clip_sd)
+            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
             m, u = clip.load_sd(clip_sd, full_model=True)
             if len(m) > 0:
                 m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
@@ -589,14 +608,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         logging.debug("left over keys: {}".format(left_over))

     if output_model:
-        _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
+        _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), ckpt_name=os.path.basename(ckpt_path))
         if inital_load_device != torch.device("cpu"):
-            load_models_gpu([_model_patcher])
+            model_management.load_models_gpu([_model_patcher], force_full_load=True)

     return (_model_patcher, clip, vae, clipvision)


-def load_unet_state_dict(sd, dtype=None):  # load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options: dict = None):  # load unet in diffusers or regular format
+    if model_options is None:
+        model_options = {}
+    dtype = model_options.get("dtype", None)

     # Allow loading unets from checkpoint files
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@@ -638,6 +660,7 @@ def load_unet_state_dict(sd, dtype=None):  # load unet in diffusers or regular f

     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
+    model_config.custom_operations = model_options.get("custom_operations", None)
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
@@ -647,15 +670,27 @@ def load_unet_state_dict(sd, dtype=None):  # load unet in diffusers or regular f
     return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)


-def load_unet(unet_path, dtype=None):
+def load_diffusion_model(unet_path, model_options: dict = None):
+    if model_options is None:
+        model_options = {}
     sd = utils.load_torch_file(unet_path)
-    model = load_unet_state_dict(sd, dtype=dtype)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options)
     if model is None:
         logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
     return model


+def load_unet(unet_path, dtype=None):
+    print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
+    return load_diffusion_model(unet_path, model_options={"dtype": dtype})
+
+
+def load_unet_state_dict(sd, dtype=None):
+    print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
+    return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
+
+
 def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
     clip_sd = None
     load_models = [model]
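A hedged migration sketch for callers of the deprecated loaders above, assuming the comfy package layout; the checkpoint path is a placeholder:

    import torch
    from comfy import sd

    # Previously: model = sd.load_unet(unet_path, dtype=torch.float8_e4m3fn)
    model = sd.load_diffusion_model(
        "models/unet/example.safetensors",  # placeholder path
        model_options={"dtype": torch.float8_e4m3fn},
    )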
@@ -16,7 +16,7 @@ except ImportError:
 from typing import Tuple, Sequence, TypeVar, Callable

 import torch
-from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMixin
+from transformers import CLIPTokenizer, PreTrainedTokenizerBase

 from . import clip_model
 from . import model_management
@@ -66,7 +66,7 @@ class ClipTokenWeightEncoder:

         output = []
         for k in range(0, sections):
-            z = out[k:k+1]
+            z = out[k:k + 1]
             if has_weights:
                 z_empty = out[-1]
                 for i in range(len(z)):
@@ -112,7 +112,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):

             config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json", package=__package__)
-
             self.operations = ops.manual_cast
             self.transformer = model_class(config, dtype, device, self.operations)
             self.num_layers = self.transformer.num_layers
@@ -389,6 +388,18 @@ def expand_directory_list(directories):
     return list(dirs)


+def bundled_embed(embed, prefix, suffix):  # bundled embedding in lora format
+    i = 0
+    out_list = []
+    for k in embed:
+        if k.startswith(prefix) and k.endswith(suffix):
+            out_list.append(embed[k])
+    if len(out_list) == 0:
+        return None
+
+    return torch.cat(out_list, dim=0)
+
+
 def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]
@@ -455,8 +466,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     elif embed_key is not None and embed_key in embed:
         embed_out = embed[embed_key]
     else:
-        values = embed.values()
-        embed_out = next(iter(values))
+        embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
+        if embed_out is None:
+            embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
+        if embed_out is None:
+            values = embed.values()
+            embed_out = next(iter(values))
     return embed_out
@@ -631,6 +646,7 @@ class SDTokenizer:
     def state_dict(self):
         return {}

+
 SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
@@ -664,6 +680,7 @@ class SD1Tokenizer:
     def state_dict(self):
         return {}

+
 class SD1ClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs):
         super().__init__()
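For example, bundled LoRA-format embeddings picked up by the new bundled_embed helper are concatenated along dim 0; the key names below are hypothetical illustrations of the expected prefix/suffix:

    import torch

    embed = {
        "bundle_emb.token_a.string_to_param.*": torch.zeros(2, 768),
        "bundle_emb.token_b.string_to_param.*": torch.zeros(1, 768),
    }
    merged = bundled_embed(embed, "bundle_emb.", ".string_to_param.*")
    assert merged.shape == (3, 768)  # tensors matching both prefix and suffix are concatenated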
@@ -640,9 +640,9 @@ class Flux(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.Flux

-    memory_usage_factor = 2.6
+    memory_usage_factor = 2.8

-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]
@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+"""
 import torch
 from . import model_base
 from . import utils
@@ -30,6 +48,7 @@ class BASE:
     memory_usage_factor = 2.0

     manual_cast_dtype = None
+    custom_operations = None

     @classmethod
     def matches(s, unet_config, state_dict=None):
@@ -1,3 +1,20 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+"""
 from __future__ import annotations

 import contextlib
@@ -472,8 +489,33 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
             key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
             key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

-        block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
-                     "attn.to_out.0.bias": "img_attn.proj.bias",
+            k = "{}.attn.".format(prefix_from)
+            qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
+            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
+            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
+            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
+
+        block_map = {
+            "attn.to_out.0.weight": "img_attn.proj.weight",
+            "attn.to_out.0.bias": "img_attn.proj.bias",
+            "norm1.linear.weight": "img_mod.lin.weight",
+            "norm1.linear.bias": "img_mod.lin.bias",
+            "norm1_context.linear.weight": "txt_mod.lin.weight",
+            "norm1_context.linear.bias": "txt_mod.lin.bias",
+            "attn.to_add_out.weight": "txt_attn.proj.weight",
+            "attn.to_add_out.bias": "txt_attn.proj.bias",
+            "ff.net.0.proj.weight": "img_mlp.0.weight",
+            "ff.net.0.proj.bias": "img_mlp.0.bias",
+            "ff.net.2.weight": "img_mlp.2.weight",
+            "ff.net.2.bias": "img_mlp.2.bias",
+            "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
+            "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
+            "ff_context.net.2.weight": "txt_mlp.2.weight",
+            "ff_context.net.2.bias": "txt_mlp.2.bias",
+            "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
+            "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
+            "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
+            "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
         }

         for k in block_map:
@@ -489,15 +531,41 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
             key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
             key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
             key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
-            key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
+            key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))

-        block_map = {#TODO
+        block_map = {
+            "norm.linear.weight": "modulation.lin.weight",
+            "norm.linear.bias": "modulation.lin.bias",
+            "proj_out.weight": "linear2.weight",
+            "proj_out.bias": "linear2.bias",
+            "attn.norm_q.weight": "norm.query_norm.scale",
+            "attn.norm_k.weight": "norm.key_norm.scale",
         }

         for k in block_map:
             key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

-    MAP_BASIC = { #TODO
+    MAP_BASIC = {
+        ("final_layer.linear.bias", "proj_out.bias"),
+        ("final_layer.linear.weight", "proj_out.weight"),
+        ("img_in.bias", "x_embedder.bias"),
+        ("img_in.weight", "x_embedder.weight"),
+        ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
+        ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
+        ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
+        ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
+        ("txt_in.bias", "context_embedder.bias"),
+        ("txt_in.weight", "context_embedder.weight"),
+        ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
+        ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
+        ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
+        ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
+        ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
+        ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
+        ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
+        ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
+        ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
+        ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
     }

     for k in MAP_BASIC:
@@ -859,6 +927,8 @@ def seed_for_block(seed):
     numpy_rng_state = np.random.get_state()
     if torch.cuda.is_available():
         cuda_rng_state = torch.cuda.get_rng_state_all()
+    else:
+        cuda_rng_state = None

     # Set the new seed
     torch.manual_seed(seed)
|
|||||||
cond = output.pop("cond")
|
cond = output.pop("cond")
|
||||||
return ([[cond, output]], )
|
return ([[cond, output]], )
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,9 +11,10 @@ from typing import Any, Dict, Optional, List, Callable, Union
 import torch
 from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
     PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
-    LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel
+    LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel, AutoModelForCausalLM
 from typing_extensions import TypedDict

+from comfy import model_management
 from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
 from comfy.language.language_types import ProcessorResult
 from comfy.language.transformers_model_management import TransformersManagedModel
@@ -28,9 +29,9 @@ _AUTO_CHAT_TEMPLATE = "default"
 try:
     from llava import model

-    logging.info("Additional LLaVA models are now supported")
+    logging.debug("Additional LLaVA models are now supported")
 except ImportError as exc:
-    logging.info(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
+    logging.debug(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")

 # aka kwargs type
 _GENERATION_KWARGS_TYPE = Dict[str, Any]
@@ -129,7 +130,7 @@ class TransformersGenerationConfig(CustomNode):
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
-                "model": ("MODEL",)
+                "model": ("MODEL", {})
             }
         }

@@ -247,13 +248,22 @@ class TransformersLoader(CustomNode):
             **hub_kwargs
         }

-        try:
-            model = AutoModel.from_pretrained(**from_pretrained_kwargs)
-        except Exception as exc_info:
-            # not yet supported by automodel
-            model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
+        # try:
+        #     import flash_attn
+        #     from_pretrained_kwargs["attn_implementation"] = "flash_attention_2"
+        # except ImportError:
+        #     logging.debug("install flash_attn for improved performance using language nodes")

         config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
+
+        if config_dict["model_type"] == "llava_next":
+            model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
+        else:
+            try:
+                model = AutoModel.from_pretrained(**from_pretrained_kwargs)
+            except Exception:
+                model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs)

         try:
             try:
                 processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs)
@@ -265,6 +275,10 @@ class TransformersLoader(CustomNode):
             processor = None
         tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs)

+        if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
+            model.enable_xformers_memory_efficient_attention()
+            logging.debug("enabled xformers memory efficient attention")
+
         model_managed = TransformersManagedModel(
             repo_id=ckpt_name,
             model=model,
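Condensed, the loader change above dispatches on the Hugging Face config's model_type before falling back through the auto classes; a sketch with a placeholder repo id:

    from transformers import (AutoModel, AutoModelForCausalLM,
                              LlavaNextForConditionalGeneration, PretrainedConfig)

    ckpt_name = "some-org/some-model"  # placeholder
    config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True)
    if config_dict["model_type"] == "llava_next":
        model = LlavaNextForConditionalGeneration.from_pretrained(ckpt_name)
    else:
        try:
            model = AutoModel.from_pretrained(ckpt_name)
        except Exception:
            model = AutoModelForCausalLM.from_pretrained(ckpt_name)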
@@ -263,6 +263,7 @@ class CLIPSave:

         metadata = {}
         if not args.disable_metadata:
+            metadata["format"] = "pt"
             metadata["prompt"] = prompt_info
             if extra_pnginfo is not None:
                 for x in extra_pnginfo:
@@ -161,7 +161,7 @@ class BooleanRequestParameter(CustomNode):
             }
         }

-    RETURN_TYPES = ("STRING",)
+    RETURN_TYPES = ("BOOLEAN",)
     FUNCTION = "execute"
     CATEGORY = "api/openapi"
@@ -108,3 +108,8 @@ NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeSD3": CLIPTextEncodeSD3,
     "ControlNetApplySD3": ControlNetApplySD3,
 }
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    # Sampling
+    "ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
+}
@ -16,7 +16,7 @@ try:
|
|||||||
from spandrel import MAIN_REGISTRY
|
from spandrel import MAIN_REGISTRY
|
||||||
|
|
||||||
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
|
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
|
||||||
logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
|
logging.debug("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -26,20 +26,16 @@ class UpscaleModelManageable(ModelManageable):
         self.ckpt_name = ckpt_name
         self.model_descriptor = model_descriptor
         self.model = model_descriptor.model
-        self.load_device = model_management.unet_offload_device()
+        self.load_device = model_management.get_torch_device()
         self.offload_device = model_management.unet_offload_device()
-        self._current_device = self.offload_device
-        self._lowvram_patch_counter = 0
+        self._input_size = (1, 512, 512)

-        # Private properties for image sizes and channels
-        self._input_size = (1, 512, 512)  # Default input size (batch, height, width)
         self._input_channels = model_descriptor.input_channels
         self._output_channels = model_descriptor.output_channels
         self.tile = 512

     @property
     def current_device(self) -> torch.device:
-        return self._current_device
+        return self.model_descriptor.device

     @property
     def input_size(self) -> tuple[int, int, int]:
@@ -65,21 +61,14 @@ class UpscaleModelManageable(ModelManageable):
     def is_clone(self, other: Any) -> bool:
         return isinstance(other, UpscaleModelManageable) and self.model is other.model

-    def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
+    def clone_has_same_weights(self, clone) -> bool:
         return self.is_clone(clone)

     def model_size(self) -> int:
-        # Calculate the size of the model parameters
         model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())

-        # Get the byte size of the model's dtype
         dtype_size = torch.finfo(self.model_dtype()).bits // 8

-        # Calculate the memory required for input and output images
         input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
         output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size

-        # Add some extra memory for processing
         extra_memory = (input_size + output_size) * 2  # This is an estimate, adjust as needed

         return model_params_size + input_size + output_size + extra_memory
@@ -95,30 +84,16 @@ class UpscaleModelManageable(ModelManageable):

     def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
         self.model.to(device=device_to)
-        self._current_device = device_to
-        self._lowvram_patch_counter += 1
         return self.model

-    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
-        if patch_weights:
-            self.model.to(device=device_to)
-        self._current_device = device_to
+    def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
+        self.model.to(device=device_to)
         return self.model

-    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
-        if unpatch_weights:
-            self.model.to(device=offload_device)
-        self._current_device = offload_device
+    def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        self.model.to(device=offload_device)
         return self.model

-    @property
-    def lowvram_patch_counter(self) -> int:
-        return self._lowvram_patch_counter
-
-    @lowvram_patch_counter.setter
-    def lowvram_patch_counter(self, value: int):
-        self._lowvram_patch_counter = value
-
     def __str__(self):
         if self.ckpt_name is not None:
             return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"
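In the model_size() hunk above, the VRAM estimate is parameter bytes plus tile-clamped input and output buffers, with a 2x allowance for intermediates. A minimal standalone sketch of the same arithmetic, where the upscale factor, channel counts, and dtype are illustrative assumptions rather than values taken from the diff:

    import torch

    def estimate_upscaler_vram(model: torch.nn.Module,
                               input_size=(1, 512, 512),
                               input_channels=3,
                               output_channels=3,
                               scale=4,            # assumed upscale factor, e.g. a 4x ESRGAN
                               tile=512,
                               dtype=torch.float32) -> int:
        # Bytes held by the weights themselves.
        params = sum(p.numel() * p.element_size() for p in model.parameters())
        # Bytes per element for the activation dtype.
        dtype_size = torch.finfo(dtype).bits // 8
        # Input buffer is clamped to the tile size, matching the tiled upscale path.
        in_bytes = (input_size[0] * min(tile, input_size[1]) * min(tile, input_size[2])
                    * input_channels * dtype_size)
        # Output buffer scales the (tiled) input by the assumed upscale factor.
        out_bytes = (input_size[0] * min(tile, input_size[1]) * scale * min(tile, input_size[2]) * scale
                     * output_channels * dtype_size)
        # Rough allowance for intermediate activations, mirroring the 2x term in the diff.
        extra = (in_bytes + out_bytes) * 2
        return params + in_bytes + out_bytes + extra

Clamping the spatial dimensions to the tile size is what keeps the estimate bounded for large inputs, since the tiled path presumably materialises at most one tile on the GPU at a time.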
@@ -46,7 +46,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text


 @pytest.mark.asyncio
-def test_known_repos(tmp_path_factory):
+async def test_known_repos(tmp_path_factory):
     prev_hub_cache = os.getenv("HF_HUB_CACHE")
     os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
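The test_known_repos change above turns a plain function into a coroutine; with pytest-asyncio, the asyncio marker only drives an async def, so the previous form could not await the hub calls it exercises. A sketch of the resulting pattern, where the body and the cache-restoring cleanup are assumptions, not code from this diff:

    import os
    import pytest

    @pytest.mark.asyncio
    async def test_known_repos(tmp_path_factory):
        # Redirect the Hugging Face cache into a throwaway directory for this test.
        prev_hub_cache = os.getenv("HF_HUB_CACHE")
        os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
        try:
            ...  # await the repo download/lookup under test here
        finally:
            # Restore the original cache location so later tests are unaffected.
            if prev_hub_cache is None:
                os.environ.pop("HF_HUB_CACHE", None)
            else:
                os.environ["HF_HUB_CACHE"] = prev_hub_cache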
@@ -39,12 +39,17 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
     workflow = json.loads(workflow_file.read_text())

     prompt = Prompt.validate(workflow)
-    # todo: add all the models we want to test a bit more elegantly
+    # todo: add all the models we want to test a bit m2ore elegantly
     outputs = await client.queue_prompt(prompt)

     if any(v.class_type == "SaveImage" for v in prompt.values()):
         save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
         assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None
     elif any(v.class_type == "SaveAudio" for v in prompt.values()):
-        save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
-        assert outputs[save_image_node_id]["audio"][0]["filename"] is not None
+        save_audio_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
+        assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None
+    elif any(v.class_type == "PreviewString" for v in prompt.values()):
+        save_image_node_id = next(key for key in prompt if prompt[key].class_type == "PreviewString")
+        output_str = outputs[save_image_node_id]["string"][0]
+        assert output_str is not None
+        assert len(output_str) > 0
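The workflow JSON files added below are what this parametrized test consumes: each file is a prompt graph keyed by node id, and the assertion branch is chosen from the terminal node's class_type (SaveImage, SaveAudio, or PreviewString). A small sketch of the shape those assertions rely on, using a hypothetical local path rather than the packaged test resources:

    import json
    from pathlib import Path

    # Hypothetical path; the test itself reads these files as packaged resources.
    workflow_path = Path("tests/inference/workflows/image-upscale-with-model-0.json")
    workflow = json.loads(workflow_path.read_text())

    # Prompt.validate (seen in the diff) parses this same structure; here we only
    # check the shape the assertions depend on: every node declares a class_type
    # and inputs, and the graph ends in a recognised output node.
    assert all("class_type" in node and "inputs" in node for node in workflow.values())
    terminal_types = {"SaveImage", "SaveAudio", "PreviewString"}
    assert any(node["class_type"] in terminal_types for node in workflow.values())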
53
tests/inference/workflows/image-upscale-with-model-0.json
Normal file
@@ -0,0 +1,53 @@
{
  "17": {
    "inputs": {
      "value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
      "name": "",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "ImageRequestParameter",
    "_meta": { "title": "ImageRequestParameter" }
  },
  "19": {
    "inputs": { "model_name": "RealESRGAN_x4plus.pth" },
    "class_type": "UpscaleModelLoader",
    "_meta": { "title": "Load Upscale Model" }
  },
  "20": {
    "inputs": {
      "upscale_model": ["19", 0],
      "image": ["17", 0]
    },
    "class_type": "ImageUpscaleWithModel",
    "_meta": { "title": "Upscale Image (using Model)" }
  },
  "21": {
    "inputs": {
      "filename_prefix": "ComfyUI",
      "images": ["20", 0]
    },
    "class_type": "SaveImage",
    "_meta": { "title": "Save Image" }
  }
}
91
tests/inference/workflows/llava-0.json
Normal file
@@ -0,0 +1,91 @@
{
  "1": {
    "inputs": {
      "ckpt_name": "llava-hf/llava-v1.6-mistral-7b-hf",
      "subfolder": ""
    },
    "class_type": "TransformersLoader",
    "_meta": { "title": "TransformersLoader" }
  },
  "3": {
    "inputs": {
      "max_new_tokens": 512,
      "repetition_penalty": 0,
      "seed": 2013744903,
      "use_cache": true,
      "__tokens": "\n\nThis is a black and white sketch of a woman. The image is stylized and does not provide enough detail to identify the specific person being depicted. It appears to be a portrait with a focus on the facial features and the hair, which is styled in a way that suggests it might be from a historical or classical period. The style of the drawing is reminiscent of the works of artists who specialize in portraiture, such as those from the Renaissance or the 19th century. </s>",
      "model": ["1", 0],
      "tokens": ["4", 0]
    },
    "class_type": "TransformersGenerate",
    "_meta": { "title": "TransformersGenerate" }
  },
  "4": {
    "inputs": {
      "prompt": "Who is this?",
      "chat_template": "llava-v1.6-mistral-7b-hf",
      "model": ["1", 0],
      "images": ["8", 0]
    },
    "class_type": "OneShotInstructTokenize",
    "_meta": { "title": "OneShotInstructTokenize" }
  },
  "5": {
    "inputs": {
      "value": ["3", 0],
      "output": "\n\nThis is a black and white sketch of a woman. The image is stylized and does not provide enough detail to identify the specific person being depicted. It appears to be a portrait with a focus on the facial features and the hair, which is styled in a way that suggests it might be from a historical or classical period. The style of the drawing is reminiscent of the works of artists who specialize in portraiture, such as those from the Renaissance or the 19th century. "
    },
    "class_type": "PreviewString",
    "_meta": { "title": "PreviewString" }
  },
  "6": {
    "inputs": {
      "value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
      "name": "",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "ImageRequestParameter",
    "_meta": { "title": "ImageRequestParameter" }
  },
  "8": {
    "inputs": {
      "upscale_method": "nearest-exact",
      "megapixels": 1,
      "image": ["6", 0]
    },
    "class_type": "ImageScaleToTotalPixels",
    "_meta": { "title": "ImageScaleToTotalPixels" }
  }
}
60
tests/inference/workflows/phi-3-0.json
Normal file
@@ -0,0 +1,60 @@
{
  "1": {
    "inputs": {
      "ckpt_name": "microsoft/Phi-3-mini-4k-instruct",
      "subfolder": ""
    },
    "class_type": "TransformersLoader",
    "_meta": { "title": "TransformersLoader" }
  },
  "3": {
    "inputs": {
      "max_new_tokens": 512,
      "repetition_penalty": 0,
      "seed": 2514389986,
      "use_cache": true,
      "__tokens": "The question \"What comes after apple?\" can be interpreted in a few ways. If we're discussing the alphabetical sequence, the letter that comes after 'A' (for apple) is 'B'. If we're discussing a sequence of fruits, it could be any fruit that follows apple in a particular list or context. For example, in a list of fruits, banana might come after apple.<|end|>",
      "model": ["1", 0],
      "tokens": ["4", 0]
    },
    "class_type": "TransformersGenerate",
    "_meta": { "title": "TransformersGenerate" }
  },
  "4": {
    "inputs": {
      "prompt": "What comes after apple?",
      "chat_template": "phi-3",
      "model": ["1", 0]
    },
    "class_type": "OneShotInstructTokenize",
    "_meta": { "title": "OneShotInstructTokenize" }
  },
  "5": {
    "inputs": {
      "value": ["3", 0],
      "output": "The question \"What comes after apple?\" can be interpreted in a few ways. If we're discussing the alphabetical sequence, the letter that comes after 'A' (for apple) is 'B'. If we're discussing a sequence of fruits, it could be any fruit that follows apple in a particular list or context. For example, in a list of fruits, banana might come after apple."
    },
    "class_type": "PreviewString",
    "_meta": { "title": "PreviewString" }
  }
}
223
tests/inference/workflows/sdxl-union-controlnet-0.json
Normal file
@@ -0,0 +1,223 @@
{
  "1": {
    "inputs": { "ckpt_name": "sd_xl_base_1.0.safetensors" },
    "class_type": "CheckpointLoaderSimple",
    "_meta": { "title": "Load Checkpoint" }
  },
  "2": {
    "inputs": {
      "strength": 0.5,
      "start_percent": 0,
      "end_percent": 1,
      "positive": ["3", 0],
      "negative": ["6", 0],
      "control_net": ["7", 0],
      "image": ["9", 0]
    },
    "class_type": "ControlNetApplyAdvanced",
    "_meta": { "title": "Apply ControlNet (Advanced)" }
  },
  "3": {
    "inputs": {
      "text": "a girl with blue hair",
      "clip": ["1", 1]
    },
    "class_type": "CLIPTextEncode",
    "_meta": { "title": "CLIP Text Encode (Prompt)" }
  },
  "5": {
    "inputs": {
      "add_noise": true,
      "noise_seed": 969970429360105,
      "cfg": 8,
      "model": ["1", 0],
      "positive": ["2", 0],
      "negative": ["2", 1],
      "sampler": ["13", 0],
      "sigmas": ["11", 0],
      "latent_image": ["12", 0]
    },
    "class_type": "SamplerCustom",
    "_meta": { "title": "SamplerCustom" }
  },
  "6": {
    "inputs": {
      "text": "",
      "clip": ["1", 1]
    },
    "class_type": "CLIPTextEncode",
    "_meta": { "title": "CLIP Text Encode (Prompt)" }
  },
  "7": {
    "inputs": {
      "type": "canny/lineart/anime_lineart/mlsd",
      "control_net": ["8", 0]
    },
    "class_type": "SetUnionControlNetType",
    "_meta": { "title": "SetUnionControlNetType" }
  },
  "8": {
    "inputs": { "control_net_name": "xinsir-controlnet-union-sdxl-1.0-promax.safetensors" },
    "class_type": "ControlNetLoader",
    "_meta": { "title": "Load ControlNet Model" }
  },
  "9": {
    "inputs": {
      "low_threshold": 0.4,
      "high_threshold": 0.8,
      "image": ["18", 0]
    },
    "class_type": "Canny",
    "_meta": { "title": "Canny" }
  },
  "11": {
    "inputs": {
      "model_type": "SDXL",
      "steps": 25,
      "denoise": 1
    },
    "class_type": "AlignYourStepsScheduler",
    "_meta": { "title": "AlignYourStepsScheduler" }
  },
  "12": {
    "inputs": {
      "width": 1024,
      "height": 1024,
      "batch_size": 1
    },
    "class_type": "EmptyLatentImage",
    "_meta": { "title": "Empty Latent Image" }
  },
  "13": {
    "inputs": {
      "eta": 1,
      "s_noise": 1
    },
    "class_type": "SamplerEulerAncestral",
    "_meta": { "title": "SamplerEulerAncestral" }
  },
  "14": {
    "inputs": {
      "samples": ["5", 0],
      "vae": ["1", 2]
    },
    "class_type": "VAEDecode",
    "_meta": { "title": "VAE Decode" }
  },
  "15": {
    "inputs": {
      "filename_prefix": "ComfyUI",
      "images": ["14", 0]
    },
    "class_type": "SaveImage",
    "_meta": { "title": "Save Image" }
  },
  "17": {
    "inputs": {
      "value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
      "name": "",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "ImageRequestParameter",
    "_meta": { "title": "ImageRequestParameter" }
  },
  "18": {
    "inputs": {
      "upscale_method": "lanczos",
      "megapixels": 1,
      "image": ["17", 0]
    },
    "class_type": "ImageScaleToTotalPixels",
    "_meta": { "title": "ImageScaleToTotalPixels" }
  }
}
302
tests/inference/workflows/sdxl-union-controlnet-1.json
Normal file
@@ -0,0 +1,302 @@
{
  "1": {
    "inputs": { "ckpt_name": "sd_xl_base_1.0.safetensors" },
    "class_type": "CheckpointLoaderSimple",
    "_meta": { "title": "Load Checkpoint" }
  },
  "2": {
    "inputs": {
      "strength": 0.5,
      "start_percent": 0,
      "end_percent": 1,
      "positive": ["3", 0],
      "negative": ["6", 0],
      "control_net": ["28", 0],
      "image": ["9", 0]
    },
    "class_type": "ControlNetApplyAdvanced",
    "_meta": { "title": "Apply ControlNet (Advanced)" }
  },
  "3": {
    "inputs": {
      "text": ["26", 0],
      "clip": ["1", 1]
    },
    "class_type": "CLIPTextEncode",
    "_meta": { "title": "CLIP Text Encode (Prompt)" }
  },
  "5": {
    "inputs": {
      "add_noise": ["23", 0],
      "noise_seed": ["20", 0],
      "cfg": ["19", 0],
      "model": ["1", 0],
      "positive": ["2", 0],
      "negative": ["2", 1],
      "sampler": ["24", 0],
      "sigmas": ["11", 0],
      "latent_image": ["12", 0]
    },
    "class_type": "SamplerCustom",
    "_meta": { "title": "SamplerCustom" }
  },
  "6": {
    "inputs": {
      "text": "",
      "clip": ["1", 1]
    },
    "class_type": "CLIPTextEncode",
    "_meta": { "title": "CLIP Text Encode (Prompt)" }
  },
  "9": {
    "inputs": {
      "low_threshold": 0.4,
      "high_threshold": 0.8,
      "image": ["18", 0]
    },
    "class_type": "Canny",
    "_meta": { "title": "Canny" }
  },
  "11": {
    "inputs": {
      "model_type": "SDXL",
      "steps": 25,
      "denoise": 1
    },
    "class_type": "AlignYourStepsScheduler",
    "_meta": { "title": "AlignYourStepsScheduler" }
  },
  "12": {
    "inputs": {
      "width": 1024,
      "height": 1024,
      "batch_size": 1
    },
    "class_type": "EmptyLatentImage",
    "_meta": { "title": "Empty Latent Image" }
  },
  "14": {
    "inputs": {
      "samples": ["5", 0],
      "vae": ["1", 2]
    },
    "class_type": "VAEDecode",
    "_meta": { "title": "VAE Decode" }
  },
  "15": {
    "inputs": {
      "filename_prefix": "ComfyUI",
      "images": ["14", 0]
    },
    "class_type": "SaveImage",
    "_meta": { "title": "Save Image" }
  },
  "17": {
    "inputs": {
      "value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
      "name": "",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "ImageRequestParameter",
    "_meta": { "title": "ImageRequestParameter" }
  },
  "18": {
    "inputs": {
      "upscale_method": "lanczos",
      "megapixels": 1,
      "image": ["17", 0]
    },
    "class_type": "ImageScaleToTotalPixels",
    "_meta": { "title": "ImageScaleToTotalPixels" }
  },
  "19": {
    "inputs": {
      "value": 8,
      "name": "cfg",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "FloatRequestParameter",
    "_meta": { "title": "FloatRequestParameter" }
  },
  "20": {
    "inputs": {
      "value": 0,
      "name": "seed",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "IntRequestParameter",
    "_meta": { "title": "IntRequestParameter" }
  },
  "23": {
    "inputs": {
      "value": true,
      "name": "add_noise",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "BooleanRequestParameter",
    "_meta": { "title": "BooleanRequestParameter" }
  },
  "24": {
    "inputs": { "sampler_name": ["25", 0] },
    "class_type": "KSamplerSelect",
    "_meta": { "title": "KSamplerSelect" }
  },
  "25": {
    "inputs": {
      "value": "euler",
      "name": "sampler_name",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "StringEnumRequestParameter",
    "_meta": { "title": "StringEnumRequestParameter" }
  },
  "26": {
    "inputs": {
      "value": "a girl with blue hair",
      "name": "",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "StringRequestParameter",
    "_meta": { "title": "StringRequestParameter" }
  },
  "27": {
    "inputs": { "control_net_name": "xinsir-controlnet-union-sdxl-1.0-promax.safetensors" },
    "class_type": "ControlNetLoader",
    "_meta": { "title": "Load ControlNet Model" }
  },
  "28": {
    "inputs": {
      "type": "canny/lineart/anime_lineart/mlsd",
      "control_net": ["27", 0]
    },
    "class_type": "SetUnionControlNetType",
    "_meta": { "title": "SetUnionControlNetType" }
  }
}
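This second ControlNet workflow differs from the first mainly in that literal values are replaced by *RequestParameter nodes (cfg, seed, add_noise, sampler_name, and the prompt text), so a caller can supply them per request. A hedged sketch that only inspects the JSON to list those declared parameters, with the file path as an assumption:

    import json
    from pathlib import Path

    # Hypothetical local path to the workflow added above.
    wf = json.loads(Path("tests/inference/workflows/sdxl-union-controlnet-1.json").read_text())

    # Collect every request-parameter node, keyed by its declared name
    # (falling back to the node id when the name is left empty).
    params = {
        node["inputs"].get("name") or node_id: node["class_type"]
        for node_id, node in wf.items()
        if node["class_type"].endswith("RequestParameter")
    }
    print(params)  # e.g. {'cfg': 'FloatRequestParameter', 'seed': 'IntRequestParameter', ...}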