Merge commit '39fb74c5bd13a1dccf4d7293a2f7a755d9f43cbd' of github.com:comfyanonymous/ComfyUI

- Improvements to tests
 - Fixes model management
 - Fixes issues with language nodes
This commit is contained in:
doctorpangloss 2024-08-13 20:08:56 -07:00
commit 0549f35e85
46 changed files with 2304 additions and 425 deletions

View File

@ -0,0 +1,53 @@
# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Pull Request CI Workflow Runs
on:
pull_request_target:
types: [labeled]
jobs:
pr-test-stable:
if: ${{ github.event.label.name == 'Run-CI-Test' }}
strategy:
fail-fast: false
matrix:
os: [macos, linux, windows]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, win]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
use_prior_commit: 'true'
comment:
if: ${{ github.event.label.name == 'Run-CI-Test' }}
runs-on: ubuntu-latest
permissions:
pull-requests: write
steps:
- uses: actions/github-script@v6
with:
script: |
github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
})

95
.github/workflows/test-ci.yml vendored Normal file
View File

@ -0,0 +1,95 @@
# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Full Comfy CI Workflow Runs
on:
push:
branches:
- master
paths-ignore:
- 'app/**'
- 'input/**'
- 'output/**'
- 'notebooks/**'
- 'script_examples/**'
- '.github/**'
- 'web/**'
workflow_dispatch:
jobs:
test-stable:
strategy:
fail-fast: false
matrix:
os: [macos, linux, windows]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, win]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
test-win-nightly:
strategy:
fail-fast: true
matrix:
os: [windows]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- os: windows
runner_label: [self-hosted, win]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
test-unix-nightly:
strategy:
fail-fast: false
matrix:
os: [macos, linux]
python_version: ["3.11"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}

3
.gitignore vendored
View File

@ -176,4 +176,5 @@ cython_debug/
/tests-ui/data/object_info.json
/user/
*.log
web_custom_versions/
web_custom_versions/
.DS_Store

View File

@ -14,9 +14,9 @@ from typing import TypedDict
import requests
from typing_extensions import NotRequired
from comfy.cli_args import DEFAULT_VERSION_STRING
from comfy.cmd.folder_paths import add_model_folder_path
from comfy.component_model.files import get_package_as_path
from ..cli_args import DEFAULT_VERSION_STRING
from ..cmd.folder_paths import add_model_folder_path
from ..component_model.files import get_package_as_path
REQUEST_TIMEOUT = 10 # seconds

View File

@ -10,7 +10,6 @@ import time
import traceback
import typing
from os import PathLike
from pathlib import PurePath
from typing import List, Optional, Tuple
import lazy_object_proxy
@ -34,9 +33,11 @@ from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOption
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
# various functions that are declared here. It should have been a context in the first place.
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy.graph_utils import is_link, GraphBuilder
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
# order matters
from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from ..graph_utils import is_link, GraphBuilder
from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
class IsChangedCache:
@ -446,6 +447,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted
}
if isinstance(ex, model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
model_management.unload_all_models()
return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)
executed.add(unique_id)

View File

@ -55,7 +55,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
logging.debug("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags()
free_memory = flags.get("free_memory", False)

View File

@ -838,7 +838,7 @@ class PromptServer(ExecutorToClientProgress):
self.port = port
if verbose:
logging.info("Starting server\n")
logging.info("Starting server")
logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
if call_on_start is not None:
call_on_start(address, port)

View File

@ -1,4 +1,24 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from enum import Enum
import math
import os
import logging
@ -12,7 +32,8 @@ from . import latent_formats
from .cldm import cldm, mmdit
from .t2i_adapter import adapter
from .ldm.cascade import controlnet
from .ldm import hydit, flux
from .ldm.cascade import controlnet as cascade_controlnet
def broadcast_image_to(tensor, target_batch_size, batched_number):
@ -33,6 +54,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else:
return torch.cat([tensor] * batched_number, dim=0)
class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlBase:
def __init__(self, device=None):
self.cond_hint_original = None
@ -51,6 +76,8 @@ class ControlBase:
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
self.cond_hint_original = cond_hint
@ -93,6 +120,8 @@ class ControlBase:
c.latent_format = self.latent_format
c.extra_args = self.extra_args.copy()
c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@ -113,7 +142,10 @@ class ControlBase:
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
applied_to.add(x)
x *= self.strength
if self.strength_type == StrengthType.CONSTANT:
x *= self.strength
elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i))
if x.dtype != output_dtype:
x = x.to(output_dtype)
@ -142,7 +174,7 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, ckpt_name: str = None):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, ckpt_name: str = None):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
@ -154,6 +186,8 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@ -191,13 +225,16 @@ class ControlNet(ControlBase):
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
y = cond.get('y', None)
if y is not None:
y = y.to(dtype)
extra = self.extra_args.copy()
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype)
def copy(self):
@ -286,6 +323,7 @@ class ControlLora(ControlNet):
ControlBase.__init__(self, device)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
self.extra_conds += ["y"]
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
@ -338,12 +376,8 @@ class ControlLora(ControlNet):
def inference_memory_requirements(self, dtype):
return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def load_controlnet_mmdit(sd):
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
model_config = model_detection.model_config_from_unet(new_sd, "", True)
num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
def controlnet_config(sd):
model_config = model_detection.model_config_from_unet(sd, "", True)
supported_inference_dtypes = model_config.supported_inference_dtypes
@ -356,14 +390,27 @@ def load_controlnet_mmdit(sd):
else:
operations = ops.disable_weight_init
control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
def load_controlnet_mmdit(sd):
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
@ -371,8 +418,30 @@ def load_controlnet_mmdit(sd):
return control
def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
control_model = hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
latent_format = latent_formats.SDXL()
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control
def load_controlnet_flux_xlabs(sd):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet(ckpt_path, model=None):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
@ -430,7 +499,10 @@ def load_controlnet(ckpt_path, model=None):
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
return load_controlnet_mmdit(controlnet_data)
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs(controlnet_data)
else:
return load_controlnet_mmdit(controlnet_data)
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
@ -590,11 +662,11 @@ def load_t2i_adapter(t2i_data):
xl = True
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
elif "backbone.0.0.weight" in keys:
model_ad = controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
compression_ratio = 32
upscale_algorithm = 'bilinear'
elif "backbone.10.blocks.0.weight" in keys:
model_ad = controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
compression_ratio = 1
upscale_algorithm = 'nearest-exact'
else:

View File

@ -28,7 +28,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
unet = None
if unet_path is not None:
unet = sd.load_unet(unet_path)
unet = sd.load_diffusion_model(unet_path)
clip = None
textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"])

View File

@ -8,7 +8,7 @@ from tqdm.auto import trange, tqdm
from . import utils
from . import deis
import comfy.model_patcher
from .. import model_patcher
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
@ -1032,7 +1032,7 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
@ -1058,7 +1058,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):

View File

@ -77,19 +77,9 @@ class TransformersManagedModel(ModelManageable):
return self.model.config.to_dict()
@property
def lowvram_patch_counter(self):
return 0
@lowvram_patch_counter.setter
def lowvram_patch_counter(self, value: int):
warnings.warn("Not supported")
pass
load_device: torch.device
offload_device: torch.device
model: PreTrainedModel
@property
def current_device(self) -> torch.device:
return self.model.device
@ -127,12 +117,10 @@ class TransformersManagedModel(ModelManageable):
warnings.warn("Transformers models do not currently support adapters like LoRAs")
return self.model.to(device=device_to)
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
warnings.warn("Transformers models do not currently support adapters like LoRAs")
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
return self.model.to(device=device_to)
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
warnings.warn("Transformers models do not currently support adapters like LoRAs")
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
return self.model.to(device=offload_device)
def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
@ -177,7 +165,7 @@ class TransformersManagedModel(ModelManageable):
self.processor.to(device=self.load_device)
assert "<image>" in prompt.lower(), "You must specify a &lt;image&gt; token inside the prompt for it to be substituted correctly by a HuggingFace processor"
batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
batch_feature: BatchFeature = self.processor([prompt], images=images.unbind(), padding=True, return_tensors="pt")
if hasattr(self.processor, "to"):
self.processor.to(device=self.offload_device)
assert "input_ids" in batch_feature

View File

@ -0,0 +1,104 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .model import Flux
import comfy.ldm.common_dit
class ControlNetFlux(Flux):
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
# controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.input_hint_block = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples = ()
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples = block_res_samples + (img,)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
return {"output": (controlnet_block_res_samples * 10)[:19]}
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
hint = hint * 2.0 - 1.0
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)

View File

@ -2,12 +2,12 @@ import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from .math import attention, rope
from ... import ops
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
super().__init__()
@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
t.device
)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@ -48,7 +46,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
super().__init__()
@ -94,14 +91,6 @@ class SelfAttention(nn.Module):
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
@ -163,22 +152,21 @@ class DoubleStreamBlock(nn.Module):
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2), pe=pe)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
@ -186,8 +174,12 @@ class DoubleStreamBlock(nn.Module):
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = txt.clip(-65504, 65504)
return img, txt
@ -232,14 +224,17 @@ class SingleStreamBlock(nn.Module):
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
x += mod.gate * output
if x.dtype == torch.float16:
x = x.clip(-65504, 65504)
return x
class LastLayer(nn.Module):

View File

@ -38,7 +38,7 @@ class Flux(nn.Module):
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = FluxParams(**kwargs)
@ -83,7 +83,8 @@ class Flux(nn.Module):
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
@ -94,6 +95,7 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@ -112,8 +114,15 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
for i in range(len(self.double_blocks)):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
if control is not None: #Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img += add
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
@ -123,7 +132,7 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@ -138,5 +147,5 @@ class Flux(nn.Module):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

View File

@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
@ -78,10 +78,9 @@ def apply_rotary_emb(
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xq_out = (xq * cos + rotate_half(xq) * sin)
if xk is not None:
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
xk_out = (xk * cos + rotate_half(xk) * sin)
else:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]

View File

@ -0,0 +1,321 @@
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import checkpoint
from comfy.ldm.modules.diffusionmodules.mmdit import (
Mlp,
TimestepEmbedder,
PatchEmbed,
RMSNorm,
)
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool
import comfy.latent_formats
from .models import HunYuanDiTBlock, calc_rope
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
class HunYuanControlNet(nn.Module):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
input_size: tuple
The size of the input image.
patch_size: int
The size of the patch.
in_channels: int
The number of input channels.
hidden_size: int
The hidden size of the transformer backbone.
depth: int
The number of transformer blocks.
num_heads: int
The number of attention heads.
mlp_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
log_fn: callable
The logging function.
"""
def __init__(
self,
input_size: tuple = 128,
patch_size: int = 2,
in_channels: int = 4,
hidden_size: int = 1408,
depth: int = 40,
num_heads: int = 16,
mlp_ratio: float = 4.3637,
text_states_dim=1024,
text_states_dim_t5=2048,
text_len=77,
text_len_t5=256,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
size_cond=False,
use_style_cond=False,
learn_sigma=True,
norm="layer",
log_fn: callable = print,
attn_precision=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.log_fn = log_fn
self.depth = depth
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.text_states_dim = text_states_dim
self.text_states_dim_t5 = text_states_dim_t5
self.text_len = text_len
self.text_len_t5 = text_len_t5
self.size_cond = size_cond
self.use_style_cond = use_style_cond
self.norm = norm
self.dtype = dtype
self.latent_format = comfy.latent_formats.SDXL
self.mlp_t5 = nn.Sequential(
nn.Linear(
self.text_states_dim_t5,
self.text_states_dim_t5 * 4,
bias=True,
dtype=dtype,
device=device,
),
nn.SiLU(),
nn.Linear(
self.text_states_dim_t5 * 4,
self.text_states_dim,
bias=True,
dtype=dtype,
device=device,
),
)
# learnable replace
self.text_embedding_padding = nn.Parameter(
torch.randn(
self.text_len + self.text_len_t5,
self.text_states_dim,
dtype=dtype,
device=device,
)
)
# Attention pooling
pooler_out_dim = 1024
self.pooler = AttentionPool(
self.text_len_t5,
self.text_states_dim_t5,
num_heads=8,
output_dim=pooler_out_dim,
dtype=dtype,
device=device,
operations=operations,
)
# Dimension of the extra input vectors
self.extra_in_dim = pooler_out_dim
if self.size_cond:
# Image size and crop size conditions
self.extra_in_dim += 6 * 256
if self.use_style_cond:
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(
1, hidden_size, dtype=dtype, device=device
)
self.extra_in_dim += hidden_size
# Text embedding for `add`
self.x_embedder = PatchEmbed(
input_size,
patch_size,
in_channels,
hidden_size,
dtype=dtype,
device=device,
operations=operations,
)
self.t_embedder = TimestepEmbedder(
hidden_size, dtype=dtype, device=device, operations=operations
)
self.extra_embedder = nn.Sequential(
operations.Linear(
self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
),
nn.SiLU(),
operations.Linear(
hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
),
)
# Image embedding
num_patches = self.x_embedder.num_patches
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList(
[
HunYuanDiTBlock(
hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=self.text_states_dim,
qk_norm=qk_norm,
norm_type=self.norm,
skip=False,
attn_precision=attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
for _ in range(19)
]
)
# Input zero linear for the first block
self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
# Output zero linear for the every block
self.after_proj_list = nn.ModuleList(
[
operations.Linear(
self.hidden_size, self.hidden_size, dtype=dtype, device=device
)
for _ in range(len(self.blocks))
]
)
def forward(
self,
x,
hint,
timesteps,
context,#encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
return_dict=False,
**kwarg,
):
"""
Forward pass of the encoder.
Parameters
----------
x: torch.Tensor
(B, D, H, W)
t: torch.Tensor
(B)
encoder_hidden_states: torch.Tensor
CLIP text embedding, (B, L_clip, D)
text_embedding_mask: torch.Tensor
CLIP text embedding mask, (B, L_clip)
encoder_hidden_states_t5: torch.Tensor
T5 text embedding, (B, L_t5, D)
text_embedding_mask_t5: torch.Tensor
T5 text embedding mask, (B, L_t5)
image_meta_size: torch.Tensor
(B, 6)
style: torch.Tensor
(B)
cos_cis_img: torch.Tensor
sin_cis_img: torch.Tensor
return_dict: bool
Whether to return a dictionary.
"""
condition = hint
if condition.shape[0] == 1:
condition = torch.repeat_interleave(condition, x.shape[0], dim=0)
text_states = context # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
text_states[:, -self.text_len :] = torch.where(
text_states_mask[:, -self.text_len :].unsqueeze(2),
text_states[:, -self.text_len :],
padding[: self.text_len],
)
text_states_t5[:, -self.text_len_t5 :] = torch.where(
text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
text_states_t5[:, -self.text_len_t5 :],
padding[self.text_len :],
)
text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,2051024
# _, _, oh, ow = x.shape
# th, tw = oh // self.patch_size, ow // self.patch_size
# Get image RoPE embedding according to `reso`lution.
freqs_cis_img = calc_rope(
x, self.patch_size, self.hidden_size // self.num_heads
) # (cos_cis_img, sin_cis_img)
# ========================= Build time and image embedding =========================
t = self.t_embedder(timesteps, dtype=self.dtype)
x = self.x_embedder(x)
# ========================= Concatenate all extra vectors =========================
# Build text tokens with pooling
extra_vec = self.pooler(encoder_hidden_states_t5)
# Build image meta size tokens if applicable
# if image_meta_size is not None:
# image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
# if image_meta_size.dtype != self.dtype:
# image_meta_size = image_meta_size.half()
# image_meta_size = image_meta_size.view(-1, 6 * 256)
# extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
# Build style tokens
if style is not None:
style_embedding = self.style_embedder(style)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
# ========================= Deal with Condition =========================
condition = self.x_embedder(condition)
# ========================= Forward pass through HunYuanDiT blocks =========================
controls = []
x = x + self.before_proj(condition) # add condition
for layer, block in enumerate(self.blocks):
x = block(x, c, text_states, freqs_cis_img)
controls.append(self.after_proj_list[layer](x)) # zero linear for output
return {"output": controls}

View File

@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size):
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
rope = (rope[0].to(x), rope[1].to(x))
return rope
@ -91,6 +92,8 @@ class HunYuanDiTBlock(nn.Module):
# Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([x, skip], dim=-1)
if cat.dtype != x.dtype:
cat = cat.to(x.dtype)
cat = self.skip_norm(cat)
x = self.skip_linear(cat)
@ -362,6 +365,8 @@ class HunYuanDiT(nn.Module):
c = t + self.extra_embedder(extra_vec) # [B, D]
controls = None
if control:
controls = control.get("output", None)
# ========================= Forward pass through HunYuanDiT blocks =========================
skips = []
for layer, block in enumerate(self.blocks):

View File

@ -411,17 +411,17 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
optimized_attention = attention_basic
if model_management.xformers_enabled():
logging.info("Using xformers cross attention")
logging.debug("Using xformers cross attention")
optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch cross attention")
logging.debug("Using pytorch cross attention")
optimized_attention = attention_pytorch
else:
if args.use_split_cross_attention:
logging.info("Using split optimization for cross attention")
logging.debug("Using split optimization for cross attention")
optimized_attention = attention_split
else:
logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
logging.debug("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
optimized_attention_masked = optimized_attention

View File

@ -268,13 +268,13 @@ class AttnBlock(nn.Module):
padding=0)
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
logging.debug("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention in VAE")
logging.debug("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
logging.info("Using split attention in VAE")
logging.debug("Using split attention in VAE")
self.optimized_attention = normal_attention
def forward(self, x):

View File

@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import logging
from . import utils
from . import model_base
@ -216,11 +234,17 @@ def model_lora_keys_clip(model, key_map={}):
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
key_map[lora_key] = k
for k in sdk: #OneTrainer SD3 lora
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
for k in sdk:
if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
k = "clip_g.transformer.text_projection.weight"
if k in sdk:
@ -243,6 +267,7 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:

View File

@ -1,7 +1,26 @@
import logging
from enum import Enum
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import logging
import math
from enum import Enum
from typing import TypeVar, Type
import torch
from . import conds
@ -11,15 +30,16 @@ from . import ops
from . import utils
from .ldm.audio.dit import AudioDiffusionTransformer
from .ldm.audio.embedders import NumberConditioner
from .ldm.aura.mmdit import MMDiT as AuraMMDiT
from .ldm.cascade.stage_b import StageB
from .ldm.cascade.stage_c import StageC
from .ldm.flux import model as flux_model
from .ldm.hydit.models import HunYuanDiT
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.aura.mmdit import MMDiT as AuraMMDiT
from .ldm.hydit.models import HunYuanDiT
from .ldm.flux import model as flux_model
class ModelType(Enum):
EPS = 1
@ -68,26 +88,33 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
TModule = TypeVar('TModule', bound=torch.nn.Module)
class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
def __init__(self, model_config, model_type=ModelType.EPS, device: torch.device = None, unet_model: Type[TModule] = UNetModel):
super().__init__()
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device: torch.device = device
if not unet_config.get("disable_unet_model_creation", False):
if self.manual_cast_dtype is not None:
operations = ops.manual_cast
if model_config.custom_operations is None:
if self.manual_cast_dtype is not None:
operations = ops.manual_cast
else:
operations = ops.disable_weight_init
else:
operations = ops.disable_weight_init
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
if model_management.force_channels_last():
# todo: ???
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@ -96,7 +123,7 @@ class BaseModel(torch.nn.Module):
self.adm_channels = 0
self.concat_keys = ()
logging.info("model_type {}".format(model_type.name))
logging.debug("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
@ -669,6 +696,7 @@ class StableAudio1(BaseModel):
sd["{}{}".format(k, l)] = s[l]
return sd
class HunyuanDiT(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device, unet_model=HunYuanDiT)
@ -701,6 +729,7 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=flux_model.Flux)

View File

@ -136,8 +136,8 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["depth"] = 19
dit_config["depth_single_blocks"] = 38
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
@ -494,7 +494,12 @@ def model_config_from_diffusers_unet(state_dict):
def convert_diffusers_mmdit(state_dict, output_prefix=""):
out_sd = {}
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
sd_map = utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
@ -520,7 +525,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
old_weight = out_sd.get(t[0], None)
if old_weight is None:
old_weight = torch.empty_like(weight)
old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
if old_weight.shape[offset[0]] < offset[1] + offset[2]:
exp = list(weight.shape)
exp[offset[0]] = offset[1] + offset[2]
new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
new[:old_weight.shape[0]] = old_weight
old_weight = new
w = old_weight.narrow(offset[0], offset[1], offset[2])
else:

View File

@ -571,7 +571,7 @@ def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_i
# fix path representation
local_files = set(f.replace("\\", "/") for f in local_files)
# remove .huggingface
local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface"))
local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface") and not f.startswith(f"{repo_id}/.cache"))
# local_files.issubsetof(repo_files)
if len(local_files) > 0 and local_files.issubset(repo_files):
local_dirs_snapshots.append(str(local_path))

View File

@ -1,3 +1,20 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import logging
@ -143,10 +160,10 @@ if torch.cuda.is_available() and hasattr(torch.version, "hip") and torch.version
logging.info(f"Detected HIP device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
logging.debug("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try:
logging.info("pytorch version: {}".format(torch.version.__version__))
logging.debug("pytorch version: {}".format(torch.version.__version__))
except:
pass
@ -171,7 +188,7 @@ else:
pass
try:
XFORMERS_VERSION = xformers.version.__version__
logging.info("xformers version: {}".format(XFORMERS_VERSION))
logging.debug("xformers version: {}".format(XFORMERS_VERSION))
if XFORMERS_VERSION.startswith("0.0.18"):
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
@ -263,12 +280,12 @@ if cpu_state != CPUState.GPU:
if cpu_state == CPUState.MPS:
vram_state = VRAMState.SHARED
logging.info(f"Set vram state to: {vram_state.name}")
logging.debug(f"Set vram state to: {vram_state.name}")
DISABLE_SMART_MEMORY = args.disable_smart_memory
if DISABLE_SMART_MEMORY:
logging.info("Disabling smart memory management")
logging.debug("Disabling smart memory management")
def get_torch_device_name(device):
@ -288,7 +305,7 @@ def get_torch_device_name(device):
try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
logging.debug("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
@ -315,9 +332,12 @@ class LoadedModel:
def model_memory(self):
return self.model.model_size()
def model_offloaded_memory(self):
return self.model.model_size() - self.model.loaded_size()
def model_memory_required(self, device):
if device == self.model.current_device:
return 0
if device == self.model.current_loaded_device():
return self.model_offloaded_memory()
else:
return self.model_memory()
@ -329,15 +349,21 @@ class LoadedModel:
load_weights = not self.weights_loaded
try:
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if self.model.loaded_size() > 0:
use_more_vram = lowvram_model_memory
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram)
else:
try:
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
@ -346,15 +372,24 @@ class LoadedModel:
return self.real_model
def should_reload_model(self, force_patch_weights=False):
if force_patch_weights and self.model.lowvram_patch_counter > 0:
if force_patch_weights and self.model.lowvram_patch_counter() > 0:
return True
return False
def model_unload(self, unpatch_weights=True):
def model_unload(self, memory_to_free=None, unpatch_weights=True):
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
self.real_model = None
return True
def model_use_more_vram(self, extra_memory):
return self.model.partially_load(self.device, extra_memory)
def __eq__(self, other):
return self.model is other.model
@ -366,39 +401,59 @@ class LoadedModel:
return f"<LoadedModel>"
def use_more_memory(extra_memory, loaded_models, device):
for m in loaded_models:
if m.device == device:
extra_memory -= m.model_use_more_vram(extra_memory)
if extra_memory <= 0:
break
def offloaded_memory(loaded_models, device):
offloaded_mem = 0
for m in loaded_models:
if m.device == device:
offloaded_mem += m.model_offloaded_memory()
return offloaded_mem
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
with model_management_lock:
to_unload = []
for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
return _unload_model_clones(model, unload_weights_only, force_unload)
if len(to_unload) == 0:
return True
same_weights = 0
for i in to_unload:
if model.clone_has_same_weights(current_loaded_models[i].model):
same_weights += 1
def _unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
to_unload = []
for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
if same_weights == len(to_unload):
unload_weight = False
else:
unload_weight = True
if len(to_unload) == 0:
return True
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
same_weights = 0
for i in to_unload:
if model.clone_has_same_weights(current_loaded_models[i].model):
same_weights += 1
for i in to_unload:
logging.debug("unload clone {}{}".format(i, unload_weight))
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
if same_weights == len(to_unload):
unload_weight = False
else:
unload_weight = True
return unload_weight
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
for i in to_unload:
logging.debug("unload clone {} {}".format(i, unload_weight))
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
return unload_weight
@tracer.start_as_current_span("Free Memory")
@ -406,126 +461,158 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
span = get_current_span()
span.set_attribute("memory_required", memory_required)
with model_management_lock:
unloaded_models: List[LoadedModel] = []
can_unload = []
for i in range(len(current_loaded_models) - 1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
for x in sorted(can_unload):
i = x[-1]
if not DISABLE_SMART_MEMORY:
if get_free_memory(device) > memory_required:
break
current_loaded_models[i].model_unload()
unloaded_models.append(i)
for i in sorted(unloaded_models, reverse=True):
current_loaded_models.pop(i)
if len(unloaded_models) > 0:
soft_empty_cache()
else:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
unloaded_models = _free_memory(memory_required, device, keep_loaded)
span.set_attribute("unloaded_models", list(map(str, unloaded_models)))
return unloaded_models
def _free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
unloaded_model = []
can_unload = []
unloaded_models = []
for i in range(len(current_loaded_models) - 1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
for x in sorted(can_unload):
i = x[-1]
memory_to_free = None
if not DISABLE_SMART_MEMORY:
free_mem = get_free_memory(device)
if free_mem > memory_required:
break
memory_to_free = memory_required - free_mem
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
if len(unloaded_model) > 0:
soft_empty_cache()
else:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
return unloaded_models
@tracer.start_as_current_span("Load Models GPU")
def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None) -> None:
global vram_state
def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
span = get_current_span()
if memory_required != 0:
span.set_attribute("memory_required", memory_required)
with model_management_lock:
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required)
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required)
_load_models_gpu(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
to_load = list(map(str, models))
span.set_attribute("models", to_load)
logging.info(f"Loaded {to_load}")
models = set(models)
models_to_load = []
models_already_loaded = []
for x in models:
loaded_model = LoadedModel(x)
loaded = None
try:
loaded_model_index = current_loaded_models.index(loaded_model)
except ValueError:
loaded_model_index = None
def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
global vram_state
if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
loaded.currently_used = True
models_already_loaded.append(loaded)
if loaded is None:
models_to_load.append(loaded_model)
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
models = set(models)
models_to_load = []
models_already_loaded = []
models_freed = []
for x in models:
loaded_model = LoadedModel(x)
loaded = None
models_freed: List[LoadedModel] = []
try:
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
models_freed += free_memory(extra_mem, d, models_already_loaded)
return
loaded_model_index = current_loaded_models.index(loaded_model)
except:
loaded_model_index = None
total_memory_required = {}
for loaded_model in models_to_load:
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
loaded.currently_used = True
models_already_loaded.append(loaded)
for device in total_memory_required:
if device != torch.device("cpu"):
# todo: where does 1.3 come from?
models_freed += free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
if loaded is None:
models_to_load.append(loaded_model)
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
free_mem = get_free_memory(d)
if free_mem < minimum_memory_required:
models_to_load = free_memory(minimum_memory_required, d)
models_freed += models_to_load
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
if len(models_to_load) == 0:
return
finally:
span.set_attribute("models", list(map(str, models)))
span.set_attribute("models_to_load", list(map(str, models_to_load)))
span.set_attribute("models_freed", list(map(str, models_freed)))
logging.info(f"Requested to load {','.join(map(str, models))}, models loaded: {','.join(map(str, models_to_load))}, models freed: {','.join(map(str, models_freed))}")
total_memory_required = {}
for loaded_model in models_to_load:
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) # unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for loaded_model in models_already_loaded:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for device in total_memory_required:
if device != torch.device("cpu"):
models_freed += free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_mem = get_free_memory(d)
if free_mem > minimum_memory_required:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
span = get_current_span()
span.set_attribute("models_to_load", list(map(str, models_to_load)))
span.set_attribute("models_freed", list(map(str, models_freed)))
@_deprecate_method(message="Use load_models_gpu instead", version="0.0.2")
@ -605,6 +692,7 @@ def unet_initial_load_device(parameters, dtype):
def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
if args.bf16_unet:
return torch.bfloat16
@ -629,12 +717,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
if model_params * 2 > free_model_memory:
return fp8_dtype
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
if should_use_bf16(device, model_params=model_params, manual_cast=True):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
if torch.float16 in supported_dtypes:
return torch.float16
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
return torch.float32
@ -651,13 +749,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.flo
if bf16_supported and weight_dtype == torch.bfloat16:
return None
if fp16_supported and torch.float16 in supported_dtypes:
return torch.float16
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16
if dt == torch.bfloat16 and bf16_supported:
return torch.bfloat16
elif bf16_supported and torch.bfloat16 in supported_dtypes:
return torch.bfloat16
else:
return torch.float32
return torch.float32
def text_encoder_offload_device():
@ -679,6 +777,21 @@ def text_encoder_device():
return torch.device("cpu")
def text_encoder_initial_device(load_device, offload_device, model_size=0):
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
return offload_device
if is_device_mps(load_device):
return offload_device
mem_l = get_free_memory(load_device)
mem_o = get_free_memory(offload_device)
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
return load_device
else:
return offload_device
def text_encoder_dtype(device=None):
if args.fp8_e4m3fn_text_enc:
return torch.float8_e4m3fn

View File

@ -1,8 +1,23 @@
from __future__ import annotations
from typing import Protocol, Optional, Any
import dataclasses
from typing import Protocol, Optional, TypeVar, runtime_checkable
import torch
import torch.nn
ModelManageableT = TypeVar('ModelManageableT', bound='ModelManageable')
@runtime_checkable
class DeviceSettable(Protocol):
@property
def device(self) -> torch.device:
...
@device.setter
def device(self, value: torch.device):
...
class ModelManageable(Protocol):
@ -22,32 +37,92 @@ class ModelManageable(Protocol):
@property
def current_device(self) -> torch.device:
...
return next(self.model.parameters()).device
def is_clone(self, other: Any) -> bool:
...
def is_clone(self, other: ModelManageableT) -> bool:
return other.model is self.model
def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
...
def clone_has_same_weights(self, clone: ModelManageableT) -> bool:
return clone.model is self.model
def model_size(self) -> int:
...
from .model_management import module_size
return module_size(self.model)
def model_patches_to(self, arg: torch.device | torch.dtype):
...
pass
def model_dtype(self) -> torch.dtype:
...
return next(self.model.parameters()).dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
self.patch_model(device_to=device_to, patch_weights=False)
return self.model
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
"""
Loads the model to the device
:param device_to: the device to move the model weights to
:param patch_weights: True if the patch's weights should also be moved
:return:
"""
...
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
"""
Unloads the model by moving it to the offload device
:param offload_device:
:param unpatch_weights:
:return:
"""
...
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
...
def lowvram_patch_counter(self) -> int:
return 0
def partially_load(self, device_to: torch.device, extra_memory=0) -> int:
self.patch_model(device_to=device_to)
return self.model_size()
def partially_unload(self, device_to: torch.device, extra_memory=0) -> int:
self.unpatch_model(device_to)
return self.model_size()
def memory_required(self, input_shape) -> int:
from comfy.model_base import BaseModel
if isinstance(self.model, BaseModel):
return self.model.memory_required(input_shape=input_shape)
else:
# todo: why isn't this true?
return self.model_size()
def loaded_size(self) -> int:
if self.current_loaded_device() == self.load_device:
return self.model_size()
return 0
def current_loaded_device(self) -> torch.device:
return self.current_device
@dataclasses.dataclass
class MemoryMeasurements:
model: torch.nn.Module | DeviceSettable
model_loaded_weight_memory: int = 0
lowvram_patch_counter: int = 0
model_lowvram: bool = False
_device: torch.device | None = None
@property
def lowvram_patch_counter(self) -> int:
...
def device(self) -> torch.device:
if isinstance(self.model, DeviceSettable):
return self.model.device
else:
return self._device
@device.setter
def device(self, value: torch.device):
if isinstance(self.model, DeviceSettable):
self.model.device = value
self._device = value

View File

@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import collections
import copy
import inspect
import logging
@ -5,10 +23,12 @@ import uuid
from typing import Optional
import torch
import torch.nn
from . import model_management
from . import utils
from .model_management_types import ModelManageable
from .model_base import BaseModel
from .model_management_types import ModelManageable, MemoryMeasurements
from .types import UnetWrapperFunction
@ -69,10 +89,27 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
return model_options
def wipe_lowvram_weight(m):
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
class ModelPatcher(ModelManageable):
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False, ckpt_name: Optional[str] = None):
def __init__(self, model: torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
self.size = size
self.model = model
self.model: torch.nn.Module = model
self.patches = {}
self.backup = {}
self.object_patches = {}
@ -81,25 +118,21 @@ class ModelPatcher(ModelManageable):
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
self._current_device: torch.device
if current_device is None:
self._current_device = self.offload_device
else:
self._current_device = current_device
self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
self.patches_uuid = uuid.uuid4()
self.ckpt_name = ckpt_name
self._lowvram_patch_counter = 0
self._memory_measurements = MemoryMeasurements(self.model)
@property
def lowvram_patch_counter(self):
return self._lowvram_patch_counter
def model_device(self) -> torch.device:
return self._memory_measurements.device
@lowvram_patch_counter.setter
def lowvram_patch_counter(self, value: int):
self._lowvram_patch_counter = value
@model_device.setter
def model_device(self, value: torch.device):
self._memory_measurements.device = value
def lowvram_patch_counter(self):
return self._memory_measurements.lowvram_patch_counter
def model_size(self):
if self.size > 0:
@ -107,8 +140,12 @@ class ModelPatcher(ModelManageable):
self.size = model_management.module_size(self.model)
return self.size
def loaded_size(self):
return self._memory_measurements.model_loaded_weight_memory
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
n._memory_measurements = self._memory_measurements
n.ckpt_name = self.ckpt_name
n.patches = {}
for k in self.patches:
@ -122,11 +159,9 @@ class ModelPatcher(ModelManageable):
return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
return hasattr(other, 'model') and self.model is other.model
def clone_has_same_weights(self, clone):
def clone_has_same_weights(self, clone: "ModelPatcher"):
if not self.is_clone(clone):
return False
@ -139,7 +174,8 @@ class ModelPatcher(ModelManageable):
else:
return True
def memory_required(self, input_shape):
def memory_required(self, input_shape) -> int:
assert isinstance(self.model, BaseModel)
return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
@ -281,16 +317,16 @@ class ModelPatcher(ModelManageable):
sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None):
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
if key not in self.patches:
return
weight = utils.get_attr(self.model, key)
inplace_update = self.weight_inplace_update
inplace_update = self.weight_inplace_update or inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
if device_to is not None:
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
@ -319,31 +355,25 @@ class ModelPatcher(ModelManageable):
if device_to is not None:
self.model.to(device_to)
self._current_device = device_to
self.model_device = device_to
self._memory_measurements.model_loaded_weight_memory = self.model_size()
return self.model
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
for n, m in self.model.named_modules():
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
if not full_load and hasattr(m, "comfy_cast_weights"):
module_mem = model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
if m.comfy_cast_weights:
continue
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
@ -365,15 +395,39 @@ class ModelPatcher(ModelManageable):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
else:
if hasattr(m, "comfy_cast_weights"):
if m.comfy_cast_weights:
wipe_lowvram_weight(m)
if hasattr(m, "weight"):
self.patch_weight_to_device(weight_key, device_to)
self.patch_weight_to_device(bias_key, device_to)
m.to(device_to)
mem_counter += model_management.module_size(m)
param = list(m.parameters())
if len(param) > 0:
weight = param[0]
if weight.device == device_to:
continue
weight_to = None
if full_load: # TODO
weight_to = device_to
self.patch_weight_to_device(weight_key, device_to=weight_to) # TODO: speed this up without OOM
self.patch_weight_to_device(bias_key, device_to=weight_to)
m.to(device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
self.model_lowvram = True
self.lowvram_patch_counter = patch_counter
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self._memory_measurements.model_lowvram = True
else:
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
self._memory_measurements.model_lowvram = False
self._memory_measurements.lowvram_patch_counter += patch_counter
self._memory_measurements.model_loaded_weight_memory = mem_counter
self.model_device = device_to
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
return self.model
def calculate_weight(self, patches, weight, key):
@ -548,31 +602,28 @@ class ModelPatcher(ModelManageable):
def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights:
if self.model_lowvram:
if self._memory_measurements.model_lowvram:
for m in self.model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
wipe_lowvram_weight(m)
self.model_lowvram = False
self.lowvram_patch_counter = 0
self._memory_measurements.model_lowvram = False
self._memory_measurements.lowvram_patch_counter = 0
keys = list(self.backup.keys())
if self.weight_inplace_update:
for k in keys:
utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
utils.set_attr_param(self.model, k, self.backup[k])
for k in keys:
bk = self.backup[k]
if bk.inplace_update:
utils.copy_to_param(self.model, k, bk.weight)
else:
utils.set_attr_param(self.model, k, bk.weight)
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
self._current_device = value = device_to
self.model_device = device_to
self._memory_measurements.model_loaded_weight_memory = 0
keys = list(self.object_patches_backup.keys())
for k in keys:
@ -580,9 +631,66 @@ class ModelPatcher(ModelManageable):
self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0
patch_counter = 0
for n, m in list(self.model.named_modules())[::-1]:
if memory_to_free < memory_freed:
break
shift_lowvram = False
if hasattr(m, "comfy_cast_weights"):
module_mem = model_management.module_size(m)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if m.weight is not None and m.weight.device != device_to:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
utils.copy_to_param(self.model, key, bk.weight)
else:
utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
memory_freed += module_mem
logging.debug("freed {}".format(n))
self._memory_measurements.model_lowvram = True
self._memory_measurements.lowvram_patch_counter += patch_counter
self._memory_measurements.model_loaded_weight_memory -= memory_freed
return memory_freed
def partially_load(self, device_to, extra_memory=0):
self.unpatch_model(unpatch_weights=False)
self.patch_model(patch_weights=False)
full_load = False
if not self._memory_measurements.model_lowvram:
return 0
if self._memory_measurements.model_loaded_weight_memory + extra_memory > self.model_size():
full_load = True
current_used = self._memory_measurements.model_loaded_weight_memory
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
return self._memory_measurements.model_loaded_weight_memory - current_used
def current_loaded_device(self):
return self.model_device
@property
def current_device(self) -> torch.device:
return self._current_device
return self.current_loaded_device()
def __str__(self):
if self.ckpt_name is not None:

View File

@ -823,14 +823,14 @@ class UNETLoader:
CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype):
dtype = None
model_options = {}
if weight_dtype == "fp8_e4m3fn":
dtype = torch.float8_e4m3fn
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2":
dtype = torch.float8_e5m2
model_options["dtype"] = torch.float8_e5m2
unet_path = get_or_download("unet", unet_name, KNOWN_UNET_MODELS)
model = sd.load_unet(unet_path, dtype=dtype)
model = sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
class CLIPLoader:

View File

@ -91,7 +91,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings):
for (duration, module_name, success, new_nodes) in sorted(timings):
logging.info(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
logging.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
if raise_on_failure and len(exceptions) > 0:
try:
raise ExceptionGroup("Node import failed", exceptions)

View File

@ -20,21 +20,19 @@ from . import model_sampling
from . import sd1_clip
from . import sdxl_clip
from . import utils
from .model_management import load_models_gpu
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
from .text_encoders import hydit
from .text_encoders import sa_t5
from .text_encoders import aura_t5
from .text_encoders import flux
from .ldm.audio.autoencoder import AudioOobleckVAE
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .model_management import load_models_gpu
from .t2i_adapter import adapter
from .taesd import taesd
from .text_encoders import aura_t5
from .text_encoders import flux
from .text_encoders import hydit
from .text_encoders import sa_t5
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
@ -68,7 +66,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
class CLIP:
def __init__(self, target: CLIPTarget=None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None=None):
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0):
if tokenizer_data is None:
tokenizer_data = dict()
if no_init:
@ -79,9 +77,9 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
if "textmodel_json_config" not in params and textmodel_json_config is not None:
params['textmodel_json_config'] = textmodel_json_config
@ -90,11 +88,16 @@ class CLIP:
for dt in self.cond_stage_model.dtypes:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
if params['device'] != offload_device:
self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")
self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
def clone(self):
n = CLIP(no_init=True)
@ -476,7 +479,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = sd3_clip.SD3ClipModel
clip_target.tokenizer = sd3_clip.SD3Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config)
parameters = 0
for c in clip_data:
parameters += utils.calculate_parameters(c)
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
@ -523,15 +530,21 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
sd = utils.load_torch_file(ckpt_path)
sd_keys = sd.keys()
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, ckpt_path=ckpt_path)
if out is None:
raise RuntimeError("Could not detect model type of: {}".format(ckpt_path))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, ckpt_path: str | None = None):
clip = None
clipvision = None
vae = None
model = None
_model_patcher = None
clip_target = None
inital_load_device = None
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = utils.calculate_parameters(sd, diffusion_model_prefix)
@ -540,13 +553,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return None
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
unet_weight_dtype.append(weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
unet_dtype = model_options.get("weight_dtype", None)
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@ -570,7 +588,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
parameters = utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
@ -589,14 +608,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
logging.debug("left over keys: {}".format(left_over))
if output_model:
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device, ckpt_name=os.path.basename(ckpt_path))
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), ckpt_name=os.path.basename(ckpt_path))
if inital_load_device != torch.device("cpu"):
load_models_gpu([_model_patcher])
model_management.load_models_gpu([_model_patcher], force_full_load=True)
return (_model_patcher, clip, vae, clipvision)
def load_unet_state_dict(sd, dtype=None): # load unet in diffusers or regular format
def load_diffusion_model_state_dict(sd, model_options: dict = None): # load unet in diffusers or regular format
if model_options is None:
model_options = {}
dtype = model_options.get("dtype", None)
# Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@ -638,6 +660,7 @@ def load_unet_state_dict(sd, dtype=None): # load unet in diffusers or regular f
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
@ -647,15 +670,27 @@ def load_unet_state_dict(sd, dtype=None): # load unet in diffusers or regular f
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path, dtype=None):
def load_diffusion_model(unet_path, model_options: dict = None):
if model_options is None:
model_options = {}
sd = utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd, dtype=dtype)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def load_unet(unet_path, dtype=None):
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
def load_unet_state_dict(sd, dtype=None):
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
clip_sd = None
load_models = [model]

View File

@ -16,7 +16,7 @@ except ImportError:
from typing import Tuple, Sequence, TypeVar, Callable
import torch
from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMixin
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
from . import clip_model
from . import model_management
@ -66,7 +66,7 @@ class ClipTokenWeightEncoder:
output = []
for k in range(0, sections):
z = out[k:k+1]
z = out[k:k + 1]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
@ -112,7 +112,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json", package=__package__)
self.operations = ops.manual_cast
self.transformer = model_class(config, dtype, device, self.operations)
self.num_layers = self.transformer.num_layers
@ -389,6 +388,18 @@ def expand_directory_list(directories):
return list(dirs)
def bundled_embed(embed, prefix, suffix): # bundled embedding in lora format
i = 0
out_list = []
for k in embed:
if k.startswith(prefix) and k.endswith(suffix):
out_list.append(embed[k])
if len(out_list) == 0:
return None
return torch.cat(out_list, dim=0)
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
if isinstance(embedding_directory, str):
embedding_directory = [embedding_directory]
@ -455,8 +466,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
elif embed_key is not None and embed_key in embed:
embed_out = embed[embed_key]
else:
values = embed.values()
embed_out = next(iter(values))
embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
if embed_out is None:
embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
if embed_out is None:
values = embed.values()
embed_out = next(iter(values))
return embed_out
@ -631,6 +646,7 @@ class SDTokenizer:
def state_dict(self):
return {}
SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
@ -664,6 +680,7 @@ class SD1Tokenizer:
def state_dict(self):
return {}
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs):
super().__init__()

View File

@ -640,9 +640,9 @@ class Flux(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Flux
memory_usage_factor = 2.6
memory_usage_factor = 2.8
supported_inference_dtypes = [torch.bfloat16, torch.float32]
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]

View File

@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from . import model_base
from . import utils
@ -30,6 +48,7 @@ class BASE:
memory_usage_factor = 2.0
manual_cast_dtype = None
custom_operations = None
@classmethod
def matches(s, unet_config, state_dict=None):

View File

@ -1,3 +1,20 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import contextlib
@ -472,8 +489,33 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
k = "{}.attn.".format(prefix_from)
qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
block_map = {
"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
"norm1.linear.weight": "img_mod.lin.weight",
"norm1.linear.bias": "img_mod.lin.bias",
"norm1_context.linear.weight": "txt_mod.lin.weight",
"norm1_context.linear.bias": "txt_mod.lin.bias",
"attn.to_add_out.weight": "txt_attn.proj.weight",
"attn.to_add_out.bias": "txt_attn.proj.bias",
"ff.net.0.proj.weight": "img_mlp.0.weight",
"ff.net.0.proj.bias": "img_mlp.0.bias",
"ff.net.2.weight": "img_mlp.2.weight",
"ff.net.2.bias": "img_mlp.2.bias",
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
}
for k in block_map:
@ -489,15 +531,41 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))
block_map = {#TODO
block_map = {
"norm.linear.weight": "modulation.lin.weight",
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
}
for k in block_map:
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
MAP_BASIC = { #TODO
MAP_BASIC = {
("final_layer.linear.bias", "proj_out.bias"),
("final_layer.linear.weight", "proj_out.weight"),
("img_in.bias", "x_embedder.bias"),
("img_in.weight", "x_embedder.weight"),
("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
("txt_in.bias", "context_embedder.bias"),
("txt_in.weight", "context_embedder.weight"),
("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
}
for k in MAP_BASIC:
@ -859,6 +927,8 @@ def seed_for_block(seed):
numpy_rng_state = np.random.get_state()
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state_all()
else:
cuda_rng_state = None
# Set the new seed
torch.manual_seed(seed)

View File

@ -19,6 +19,7 @@ class CLIPTextEncodeHunyuanDiT:
cond = output.pop("cond")
return ([[cond, output]], )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
}

View File

@ -11,9 +11,10 @@ from typing import Any, Dict, Optional, List, Callable, Union
import torch
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel
LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel, AutoModelForCausalLM
from typing_extensions import TypedDict
from comfy import model_management
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel
@ -28,9 +29,9 @@ _AUTO_CHAT_TEMPLATE = "default"
try:
from llava import model
logging.info("Additional LLaVA models are now supported")
logging.debug("Additional LLaVA models are now supported")
except ImportError as exc:
logging.info(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
logging.debug(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
# aka kwargs type
_GENERATION_KWARGS_TYPE = Dict[str, Any]
@ -129,7 +130,7 @@ class TransformersGenerationConfig(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",)
"model": ("MODEL", {})
}
}
@ -247,13 +248,22 @@ class TransformersLoader(CustomNode):
**hub_kwargs
}
try:
model = AutoModel.from_pretrained(**from_pretrained_kwargs)
except Exception as exc_info:
# not yet supported by automodel
model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
# try:
# import flash_attn
# from_pretrained_kwargs["attn_implementation"] = "flash_attention_2"
# except ImportError:
# logging.debug("install flash_attn for improved performance using language nodes")
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
if config_dict["model_type"] == "llava_next":
model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
else:
try:
model = AutoModel.from_pretrained(**from_pretrained_kwargs)
except Exception:
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs)
try:
try:
processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs)
@ -265,6 +275,10 @@ class TransformersLoader(CustomNode):
processor = None
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs)
if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
logging.debug("enabled xformers memory efficient attention")
model_managed = TransformersManagedModel(
repo_id=ckpt_name,
model=model,

View File

@ -263,6 +263,7 @@ class CLIPSave:
metadata = {}
if not args.disable_metadata:
metadata["format"] = "pt"
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:

View File

@ -161,7 +161,7 @@ class BooleanRequestParameter(CustomNode):
}
}
RETURN_TYPES = ("STRING",)
RETURN_TYPES = ("BOOLEAN",)
FUNCTION = "execute"
CATEGORY = "api/openapi"

View File

@ -108,3 +108,8 @@ NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
"ControlNetApplySD3": ControlNetApplySD3,
}
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
}

View File

@ -16,7 +16,7 @@ try:
from spandrel import MAIN_REGISTRY
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
logging.info("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
logging.debug("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
except:
pass
@ -26,20 +26,16 @@ class UpscaleModelManageable(ModelManageable):
self.ckpt_name = ckpt_name
self.model_descriptor = model_descriptor
self.model = model_descriptor.model
self.load_device = model_management.unet_offload_device()
self.load_device = model_management.get_torch_device()
self.offload_device = model_management.unet_offload_device()
self._current_device = self.offload_device
self._lowvram_patch_counter = 0
# Private properties for image sizes and channels
self._input_size = (1, 512, 512) # Default input size (batch, height, width)
self._input_size = (1, 512, 512)
self._input_channels = model_descriptor.input_channels
self._output_channels = model_descriptor.output_channels
self.tile = 512
@property
def current_device(self) -> torch.device:
return self._current_device
return self.model_descriptor.device
@property
def input_size(self) -> tuple[int, int, int]:
@ -65,21 +61,14 @@ class UpscaleModelManageable(ModelManageable):
def is_clone(self, other: Any) -> bool:
return isinstance(other, UpscaleModelManageable) and self.model is other.model
def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
def clone_has_same_weights(self, clone) -> bool:
return self.is_clone(clone)
def model_size(self) -> int:
# Calculate the size of the model parameters
model_params_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
# Get the byte size of the model's dtype
dtype_size = torch.finfo(self.model_dtype()).bits // 8
# Calculate the memory required for input and output images
input_size = self._input_size[0] * min(self.tile, self._input_size[1]) * min(self.tile, self._input_size[2]) * self._input_channels * dtype_size
output_size = self.output_size[0] * self.output_size[1] * self.output_size[2] * self._output_channels * dtype_size
# Add some extra memory for processing
extra_memory = (input_size + output_size) * 2 # This is an estimate, adjust as needed
return model_params_size + input_size + output_size + extra_memory
@ -95,30 +84,16 @@ class UpscaleModelManageable(ModelManageable):
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
self.model.to(device=device_to)
self._current_device = device_to
self._lowvram_patch_counter += 1
return self.model
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
if patch_weights:
self.model.to(device=device_to)
self._current_device = device_to
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
self.model.to(device=device_to)
return self.model
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
if unpatch_weights:
self.model.to(device=offload_device)
self._current_device = offload_device
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
self.model.to(device=offload_device)
return self.model
@property
def lowvram_patch_counter(self) -> int:
return self._lowvram_patch_counter
@lowvram_patch_counter.setter
def lowvram_patch_counter(self, value: int):
self._lowvram_patch_counter = value
def __str__(self):
if self.ckpt_name is not None:
return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"

View File

@ -46,7 +46,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
@pytest.mark.asyncio
def test_known_repos(tmp_path_factory):
async def test_known_repos(tmp_path_factory):
prev_hub_cache = os.getenv("HF_HUB_CACHE")
os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))

View File

@ -39,12 +39,17 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
workflow = json.loads(workflow_file.read_text())
prompt = Prompt.validate(workflow)
# todo: add all the models we want to test a bit more elegantly
# todo: add all the models we want to test a bit m2ore elegantly
outputs = await client.queue_prompt(prompt)
if any(v.class_type == "SaveImage" for v in prompt.values()):
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None
elif any(v.class_type == "SaveAudio" for v in prompt.values()):
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
assert outputs[save_image_node_id]["audio"][0]["filename"] is not None
save_audio_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None
elif any(v.class_type == "PreviewString" for v in prompt.values()):
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "PreviewString")
output_str = outputs[save_image_node_id]["string"][0]
assert output_str is not None
assert len(output_str) > 0

View File

@ -0,0 +1,53 @@
{
"17": {
"inputs": {
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "ImageRequestParameter",
"_meta": {
"title": "ImageRequestParameter"
}
},
"19": {
"inputs": {
"model_name": "RealESRGAN_x4plus.pth"
},
"class_type": "UpscaleModelLoader",
"_meta": {
"title": "Load Upscale Model"
}
},
"20": {
"inputs": {
"upscale_model": [
"19",
0
],
"image": [
"17",
0
]
},
"class_type": "ImageUpscaleWithModel",
"_meta": {
"title": "Upscale Image (using Model)"
}
},
"21": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"20",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}

View File

@ -0,0 +1,91 @@
{
"1": {
"inputs": {
"ckpt_name": "llava-hf/llava-v1.6-mistral-7b-hf",
"subfolder": ""
},
"class_type": "TransformersLoader",
"_meta": {
"title": "TransformersLoader"
}
},
"3": {
"inputs": {
"max_new_tokens": 512,
"repetition_penalty": 0,
"seed": 2013744903,
"use_cache": true,
"__tokens": "\n\nThis is a black and white sketch of a woman. The image is stylized and does not provide enough detail to identify the specific person being depicted. It appears to be a portrait with a focus on the facial features and the hair, which is styled in a way that suggests it might be from a historical or classical period. The style of the drawing is reminiscent of the works of artists who specialize in portraiture, such as those from the Renaissance or the 19th century. </s>",
"model": [
"1",
0
],
"tokens": [
"4",
0
]
},
"class_type": "TransformersGenerate",
"_meta": {
"title": "TransformersGenerate"
}
},
"4": {
"inputs": {
"prompt": "Who is this?",
"chat_template": "llava-v1.6-mistral-7b-hf",
"model": [
"1",
0
],
"images": [
"8",
0
]
},
"class_type": "OneShotInstructTokenize",
"_meta": {
"title": "OneShotInstructTokenize"
}
},
"5": {
"inputs": {
"value": [
"3",
0
],
"output": "\n\nThis is a black and white sketch of a woman. The image is stylized and does not provide enough detail to identify the specific person being depicted. It appears to be a portrait with a focus on the facial features and the hair, which is styled in a way that suggests it might be from a historical or classical period. The style of the drawing is reminiscent of the works of artists who specialize in portraiture, such as those from the Renaissance or the 19th century. "
},
"class_type": "PreviewString",
"_meta": {
"title": "PreviewString"
}
},
"6": {
"inputs": {
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "ImageRequestParameter",
"_meta": {
"title": "ImageRequestParameter"
}
},
"8": {
"inputs": {
"upscale_method": "nearest-exact",
"megapixels": 1,
"image": [
"6",
0
]
},
"class_type": "ImageScaleToTotalPixels",
"_meta": {
"title": "ImageScaleToTotalPixels"
}
}
}

View File

@ -0,0 +1,60 @@
{
"1": {
"inputs": {
"ckpt_name": "microsoft/Phi-3-mini-4k-instruct",
"subfolder": ""
},
"class_type": "TransformersLoader",
"_meta": {
"title": "TransformersLoader"
}
},
"3": {
"inputs": {
"max_new_tokens": 512,
"repetition_penalty": 0,
"seed": 2514389986,
"use_cache": true,
"__tokens": "The question \"What comes after apple?\" can be interpreted in a few ways. If we're discussing the alphabetical sequence, the letter that comes after 'A' (for apple) is 'B'. If we're discussing a sequence of fruits, it could be any fruit that follows apple in a particular list or context. For example, in a list of fruits, banana might come after apple.<|end|>",
"model": [
"1",
0
],
"tokens": [
"4",
0
]
},
"class_type": "TransformersGenerate",
"_meta": {
"title": "TransformersGenerate"
}
},
"4": {
"inputs": {
"prompt": "What comes after apple?",
"chat_template": "phi-3",
"model": [
"1",
0
]
},
"class_type": "OneShotInstructTokenize",
"_meta": {
"title": "OneShotInstructTokenize"
}
},
"5": {
"inputs": {
"value": [
"3",
0
],
"output": "The question \"What comes after apple?\" can be interpreted in a few ways. If we're discussing the alphabetical sequence, the letter that comes after 'A' (for apple) is 'B'. If we're discussing a sequence of fruits, it could be any fruit that follows apple in a particular list or context. For example, in a list of fruits, banana might come after apple."
},
"class_type": "PreviewString",
"_meta": {
"title": "PreviewString"
}
}
}

View File

@ -0,0 +1,223 @@
{
"1": {
"inputs": {
"ckpt_name": "sd_xl_base_1.0.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"2": {
"inputs": {
"strength": 0.5,
"start_percent": 0,
"end_percent": 1,
"positive": [
"3",
0
],
"negative": [
"6",
0
],
"control_net": [
"7",
0
],
"image": [
"9",
0
]
},
"class_type": "ControlNetApplyAdvanced",
"_meta": {
"title": "Apply ControlNet (Advanced)"
}
},
"3": {
"inputs": {
"text": "a girl with blue hair",
"clip": [
"1",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"5": {
"inputs": {
"add_noise": true,
"noise_seed": 969970429360105,
"cfg": 8,
"model": [
"1",
0
],
"positive": [
"2",
0
],
"negative": [
"2",
1
],
"sampler": [
"13",
0
],
"sigmas": [
"11",
0
],
"latent_image": [
"12",
0
]
},
"class_type": "SamplerCustom",
"_meta": {
"title": "SamplerCustom"
}
},
"6": {
"inputs": {
"text": "",
"clip": [
"1",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"type": "canny/lineart/anime_lineart/mlsd",
"control_net": [
"8",
0
]
},
"class_type": "SetUnionControlNetType",
"_meta": {
"title": "SetUnionControlNetType"
}
},
"8": {
"inputs": {
"control_net_name": "xinsir-controlnet-union-sdxl-1.0-promax.safetensors"
},
"class_type": "ControlNetLoader",
"_meta": {
"title": "Load ControlNet Model"
}
},
"9": {
"inputs": {
"low_threshold": 0.4,
"high_threshold": 0.8,
"image": [
"18",
0
]
},
"class_type": "Canny",
"_meta": {
"title": "Canny"
}
},
"11": {
"inputs": {
"model_type": "SDXL",
"steps": 25,
"denoise": 1
},
"class_type": "AlignYourStepsScheduler",
"_meta": {
"title": "AlignYourStepsScheduler"
}
},
"12": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"13": {
"inputs": {
"eta": 1,
"s_noise": 1
},
"class_type": "SamplerEulerAncestral",
"_meta": {
"title": "SamplerEulerAncestral"
}
},
"14": {
"inputs": {
"samples": [
"5",
0
],
"vae": [
"1",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"15": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"14",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"17": {
"inputs": {
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "ImageRequestParameter",
"_meta": {
"title": "ImageRequestParameter"
}
},
"18": {
"inputs": {
"upscale_method": "lanczos",
"megapixels": 1,
"image": [
"17",
0
]
},
"class_type": "ImageScaleToTotalPixels",
"_meta": {
"title": "ImageScaleToTotalPixels"
}
}
}

View File

@ -0,0 +1,302 @@
{
"1": {
"inputs": {
"ckpt_name": "sd_xl_base_1.0.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"2": {
"inputs": {
"strength": 0.5,
"start_percent": 0,
"end_percent": 1,
"positive": [
"3",
0
],
"negative": [
"6",
0
],
"control_net": [
"28",
0
],
"image": [
"9",
0
]
},
"class_type": "ControlNetApplyAdvanced",
"_meta": {
"title": "Apply ControlNet (Advanced)"
}
},
"3": {
"inputs": {
"text": [
"26",
0
],
"clip": [
"1",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"5": {
"inputs": {
"add_noise": [
"23",
0
],
"noise_seed": [
"20",
0
],
"cfg": [
"19",
0
],
"model": [
"1",
0
],
"positive": [
"2",
0
],
"negative": [
"2",
1
],
"sampler": [
"24",
0
],
"sigmas": [
"11",
0
],
"latent_image": [
"12",
0
]
},
"class_type": "SamplerCustom",
"_meta": {
"title": "SamplerCustom"
}
},
"6": {
"inputs": {
"text": "",
"clip": [
"1",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"9": {
"inputs": {
"low_threshold": 0.4,
"high_threshold": 0.8,
"image": [
"18",
0
]
},
"class_type": "Canny",
"_meta": {
"title": "Canny"
}
},
"11": {
"inputs": {
"model_type": "SDXL",
"steps": 25,
"denoise": 1
},
"class_type": "AlignYourStepsScheduler",
"_meta": {
"title": "AlignYourStepsScheduler"
}
},
"12": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"14": {
"inputs": {
"samples": [
"5",
0
],
"vae": [
"1",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"15": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"14",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"17": {
"inputs": {
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "ImageRequestParameter",
"_meta": {
"title": "ImageRequestParameter"
}
},
"18": {
"inputs": {
"upscale_method": "lanczos",
"megapixels": 1,
"image": [
"17",
0
]
},
"class_type": "ImageScaleToTotalPixels",
"_meta": {
"title": "ImageScaleToTotalPixels"
}
},
"19": {
"inputs": {
"value": 8,
"name": "cfg",
"title": "",
"description": "",
"__required": true
},
"class_type": "FloatRequestParameter",
"_meta": {
"title": "FloatRequestParameter"
}
},
"20": {
"inputs": {
"value": 0,
"name": "seed",
"title": "",
"description": "",
"__required": true
},
"class_type": "IntRequestParameter",
"_meta": {
"title": "IntRequestParameter"
}
},
"23": {
"inputs": {
"value": true,
"name": "add_noise",
"title": "",
"description": "",
"__required": true
},
"class_type": "BooleanRequestParameter",
"_meta": {
"title": "BooleanRequestParameter"
}
},
"24": {
"inputs": {
"sampler_name": [
"25",
0
]
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"25": {
"inputs": {
"value": "euler",
"name": "sampler_name",
"title": "",
"description": "",
"__required": true
},
"class_type": "StringEnumRequestParameter",
"_meta": {
"title": "StringEnumRequestParameter"
}
},
"26": {
"inputs": {
"value": "a girl with blue hair",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "StringRequestParameter",
"_meta": {
"title": "StringRequestParameter"
}
},
"27": {
"inputs": {
"control_net_name": "xinsir-controlnet-union-sdxl-1.0-promax.safetensors"
},
"class_type": "ControlNetLoader",
"_meta": {
"title": "Load ControlNet Model"
}
},
"28": {
"inputs": {
"type": "canny/lineart/anime_lineart/mlsd",
"control_net": [
"27",
0
]
},
"class_type": "SetUnionControlNetType",
"_meta": {
"title": "SetUnionControlNetType"
}
}
}