Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-17 00:43:48 +08:00)

Commit ac153699fe: Merge branch 'comfyanonymous:master' into master
@@ -45,6 +45,8 @@ jobs:
 sed -i '1i../ComfyUI' ./python310._pth
 cd ..

+git clone https://github.com/comfyanonymous/taesd
+cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/

 mkdir ComfyUI_windows_portable
 mv python_embeded ComfyUI_windows_portable
@@ -59,7 +61,7 @@ jobs:

 cd ..

-"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on ComfyUI_windows_portable.7z ComfyUI_windows_portable
+"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
 mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z

 cd ComfyUI_windows_portable
@@ -31,12 +31,14 @@ jobs:
 echo 'import site' >> ./python311._pth
 curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
 ./python.exe get-pip.py
-python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+python -m pip wheel torch torchvision torchaudio aiohttp==3.8.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
 ls ../temp_wheel_dir
 ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
 sed -i '1i../ComfyUI' ./python311._pth
 cd ..

+git clone https://github.com/comfyanonymous/taesd
+cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/

 mkdir ComfyUI_windows_portable_nightly_pytorch
 mv python_embeded ComfyUI_windows_portable_nightly_pytorch
@@ -52,7 +54,7 @@ jobs:

 cd ..

-"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
+"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
 mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z

 cd ComfyUI_windows_portable_nightly_pytorch
@@ -29,7 +29,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
 - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
 - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
-- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
+- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
+- Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Starts up very fast.
 - Works fully offline: will never download anything.
 - [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@@ -69,7 +70,7 @@ There is a portable standalone build for Windows that should work for running on

 ### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z)

-Just download, extract and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
+Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints

 #### How do I share models between another UI and ComfyUI?

@@ -193,7 +194,7 @@ You can set this command line setting to disable the upcasting to fp32 in some c

 Use ```--preview-method auto``` to enable previews.

-The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
+The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.

 ## Support and dev channel

@@ -41,7 +41,15 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the
 parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
 parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
-parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
+
+fp_group = parser.add_mutually_exclusive_group()
+fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
+fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
+
+fpvae_group = parser.add_mutually_exclusive_group()
+fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
+fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")
+
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

 class LatentPreviewMethod(enum.Enum):
@@ -53,7 +61,8 @@ class LatentPreviewMethod(enum.Enum):
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)

 attn_group = parser.add_mutually_exclusive_group()
-attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
+attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")

 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
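The two hunks above (what appears to be ComfyUI's command-line argument module) replace the single --force-fp32 flag with mutually exclusive argparse groups for float precision, VAE precision, and the cross-attention implementation. A minimal sketch of what a mutually exclusive group buys, independent of ComfyUI (the parser and flags below are illustrative only):

```python
import argparse

parser = argparse.ArgumentParser()
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true")
fp_group.add_argument("--force-fp16", action="store_true")

print(parser.parse_args(["--force-fp16"]))  # Namespace(force_fp32=False, force_fp16=True)
# parser.parse_args(["--force-fp16", "--force-fp32"])  # would exit: argparse rejects the two flags together
```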
@@ -17,7 +17,7 @@
   "num_attention_heads": 20,
   "num_hidden_layers": 32,
   "pad_token_id": 1,
-  "projection_dim": 512,
+  "projection_dim": 1280,
   "torch_dtype": "float32",
   "vocab_size": 49408
 }
@@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 code2idx = {"q": 0, "k": 1, "v": 2}


-def convert_text_enc_state_dict_v20(text_enc_dict):
+def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}
     capture_qkv_weight = {}
     capture_qkv_bias = {}
     for k, v in text_enc_dict.items():
+        if not k.startswith(prefix):
+            continue
         if (
             k.endswith(".self_attn.q_proj.weight")
             or k.endswith(".self_attn.k_proj.weight")
@@ -3,12 +3,13 @@ import os
 import yaml

 import folder_paths
-from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
+from comfy.sd import load_checkpoint
 import os.path as osp
 import re
 import torch
 from safetensors.torch import load_file, save_file
-import diffusers_convert
+from . import diffusers_convert


 def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
     diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
@@ -215,10 +215,12 @@ class PositionNet(nn.Module):

     def forward(self, boxes, masks, positive_embeddings):
         B, N, _ = boxes.shape
-        masks = masks.unsqueeze(-1)
+        dtype = self.linears[0].weight.dtype
+        masks = masks.unsqueeze(-1).to(dtype)
+        positive_embeddings = positive_embeddings.to(dtype)

         # embedding position (it may includes padding as placeholder)
-        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 --> B*N*C
+        xyxy_embedding = self.fourier_embedder(boxes.to(dtype))  # B*N*4 --> B*N*C

         # learnable null embedding
         positive_null = self.null_positive_feature.view(1, 1, -1)
@@ -252,7 +254,8 @@ class Gligen(nn.Module):

         if self.lowvram == True:
             self.position_net.cpu()
-            def func_lowvram(key, x):
+            def func_lowvram(x, extra_options):
+                key = extra_options["transformer_index"]
                 module = self.module_list[key]
                 module.to(x.device)
                 r = module(x, objs)
@@ -66,6 +66,9 @@ class BatchedBrownianTree:
     """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

     def __init__(self, x, t0, t1, seed=None, **kwargs):
+        self.cpu_tree = True
+        if "cpu" in kwargs:
+            self.cpu_tree = kwargs.pop("cpu")
         t0, t1, self.sign = self.sort(t0, t1)
         w0 = kwargs.get('w0', torch.zeros_like(x))
         if seed is None:
@@ -77,7 +80,10 @@ class BatchedBrownianTree:
         except TypeError:
             seed = [seed]
             self.batched = False
-        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+        if self.cpu_tree:
+            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
+        else:
+            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

     @staticmethod
     def sort(a, b):
@@ -85,7 +91,11 @@ class BatchedBrownianTree:

     def __call__(self, t0, t1):
         t0, t1, sign = self.sort(t0, t1)
-        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+        if self.cpu_tree:
+            w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
+        else:
+            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+
         return w if self.batched else w[0]


@@ -104,10 +114,10 @@ class BrownianTreeNoiseSampler:
         internal timestep.
     """

-    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
+    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
         self.transform = transform
         t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
-        self.tree = BatchedBrownianTree(x, t0, t1, seed)
+        self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)

     def __call__(self, sigma, sigma_next):
         t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
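The cpu flag threaded through the hunks above builds the underlying torchsde.BrownianTree on the CPU, so the SDE samplers produce the same noise for a given seed regardless of which GPU (or CPU) the latents live on. A hedged usage sketch; the import path is an assumption based on the comfy/k_diffusion module referenced elsewhere in this commit:

```python
import torch
from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler  # assumed import path

x = torch.zeros(1, 4, 64, 64)  # latent-shaped tensor; only its shape/dtype/device matter here
make = lambda: BrownianTreeNoiseSampler(x, sigma_min=0.03, sigma_max=14.6, seed=1234, cpu=True)

n1 = make()(torch.tensor(10.0), torch.tensor(9.0))
n2 = make()(torch.tensor(10.0), torch.tensor(9.0))
assert torch.allclose(n1, n2)  # same seed -> identical noise, independent of the device in use
```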
@@ -543,7 +553,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
 def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     """DPM-Solver++ (stochastic)."""
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
-    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
@@ -613,8 +624,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     if solver_type not in {'heun', 'midpoint'}:
         raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

+    seed = extra_args.get("seed", None)
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
-    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])

@@ -649,3 +661,18 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
         old_denoised = denoised
         h_last = h
     return x
+
+@torch.no_grad()
+def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
+    return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
+
+@torch.no_grad()
+def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
+    return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
@@ -9,8 +9,23 @@ class LatentFormat:
 class SD15(LatentFormat):
     def __init__(self, scale_factor=0.18215):
         self.scale_factor = scale_factor
+        self.latent_rgb_factors = [
+            #   R        G      B
+            [ 0.298,  0.207,  0.208],  # L1
+            [ 0.187,  0.286,  0.173],  # L2
+            [-0.158,  0.189,  0.264],  # L3
+            [-0.184, -0.271, -0.473],  # L4
+        ]
+        self.taesd_decoder_name = "taesd_decoder.pth"

 class SDXL(LatentFormat):
     def __init__(self):
         self.scale_factor = 0.13025
+        self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL
+            #   R        G      B
+            [ 0.298,  0.207,  0.208],  # L1
+            [ 0.187,  0.286,  0.173],  # L2
+            [-0.158,  0.189,  0.264],  # L3
+            [-0.184, -0.271, -0.473],  # L4
+        ]
+        self.taesd_decoder_name = "taesdxl_decoder.pth"
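The latent_rgb_factors added above drive the fast low-resolution preview: each of the four latent channels contributes a fixed amount of red, green and blue. A sketch of that linear projection (an approximation of the preview path, not a copy of ComfyUI's previewer code):

```python
import torch

# 4x3 matrix: latent channel -> (R, G, B) contribution, values taken from the hunk above.
latent_rgb_factors = torch.tensor([
    [ 0.298,  0.207,  0.208],
    [ 0.187,  0.286,  0.173],
    [-0.158,  0.189,  0.264],
    [-0.184, -0.271, -0.473],
])

latent = torch.randn(1, 4, 64, 64)                               # B, C, H, W
rgb = torch.einsum("bchw,cr->brhw", latent, latent_rgb_factors)  # B, 3, H, W
preview = ((rgb + 1.0) / 2.0).clamp(0, 1)                        # rough [0, 1] mapping for display
```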
@@ -180,6 +180,12 @@ class DDIMSampler(object):
         )
         return samples, intermediates

+    def q_sample(self, x_start, t, noise=None):
+        if noise is None:
+            noise = torch.randn_like(x_start)
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
     @torch.no_grad()
     def ddim_sampling(self, cond, shape,
                       x_T=None, ddim_use_original_steps=False,
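The q_sample method added above is the standard forward-diffusion step x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, now living on the DDIM sampler itself instead of the wrapped model. A self-contained numeric sketch with made-up schedule values:

```python
import torch

# Illustrative schedule buffers, not the real registered buffers.
sqrt_alphas_cumprod = torch.tensor([0.99, 0.95, 0.90])
sqrt_one_minus_alphas_cumprod = (1 - sqrt_alphas_cumprod ** 2).sqrt()

x_start = torch.randn(2, 4, 8, 8)
t = torch.tensor([1, 2])                  # one timestep index per batch element
noise = torch.randn_like(x_start)
x_t = (sqrt_alphas_cumprod[t].view(-1, 1, 1, 1) * x_start
       + sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1) * noise)
```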
@@ -214,7 +220,7 @@ class DDIMSampler(object):

             if mask is not None:
                 assert x0 is not None
-                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
                 img = img_orig * mask + (1. - mask) * img

             if ucg_schedule is not None:
@@ -16,11 +16,14 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops

-# CrossAttn precision handling
-import os
-_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")

 from comfy.cli_args import args
+# CrossAttn precision handling
+if args.dont_upcast_attention:
+    print("disabling upcasting of attention")
+    _ATTN_PRECISION = "fp16"
+else:
+    _ATTN_PRECISION = "fp32"

 def exists(val):
     return val is not None
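This hunk moves the attention-precision choice from the ATTN_PRECISION environment variable to the existing --dont-upcast-attention flag. "Upcasting" here means computing the q @ k^T similarity in fp32 even when the model weights are fp16, which avoids the overflow that produces black images, at some speed cost. A rough sketch of the pattern (not the actual CrossAttention code):

```python
import torch

_ATTN_PRECISION = "fp32"  # what the flag toggles between "fp32" (default) and "fp16"

q = torch.randn(2, 77, 64, dtype=torch.float16)
k = torch.randn(2, 77, 64, dtype=torch.float16)

if _ATTN_PRECISION == "fp32":
    sim = torch.einsum("bid,bjd->bij", q.float(), k.float())  # similarity scores in fp32
else:
    sim = torch.einsum("bid,bjd->bij", q, k)                  # scores in the model dtype
attn = sim.softmax(dim=-1).to(q.dtype)                        # back to the model dtype afterwards
```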
@@ -275,7 +278,7 @@ class CrossAttentionDoggettx(nn.Module):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in

-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

         mem_free_total = model_management.get_free_memory(q.device)

@@ -311,7 +314,7 @@ class CrossAttentionDoggettx(nn.Module):
                 s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
                 first_op_done = True

-            s2 = s1.softmax(dim=-1)
+            s2 = s1.softmax(dim=-1).to(v.dtype)
             del s1

             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
@@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm

         self.in_layers = nn.Sequential(
-            normalization(channels, dtype=dtype),
+            nn.GroupNorm(32, channels, dtype=dtype),
             nn.SiLU(),
             conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
         )
@@ -244,7 +244,7 @@ class ResBlock(TimestepBlock):
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels, dtype=dtype),
+            nn.GroupNorm(32, self.out_channels, dtype=dtype),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
@@ -778,13 +778,13 @@ class UNetModel(nn.Module):
             self._feature_size += ch

         self.out = nn.Sequential(
-            normalization(ch, dtype=self.dtype),
+            nn.GroupNorm(32, ch, dtype=self.dtype),
             nn.SiLU(),
             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-            normalization(ch),
+            nn.GroupNorm(32, ch, dtype=self.dtype),
             conv_nd(dims, model_channels, n_embed, 1),
             #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
             )
@@ -821,7 +821,7 @@ class UNetModel(nn.Module):
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
         hs = []
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
         emb = self.time_embed(t_emb)

         if self.num_classes is not None:
@@ -84,7 +84,7 @@ def _summarize_chunk(
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     torch.exp(attn_weights - max_score, out=attn_weights)
-    exp_weights = attn_weights
+    exp_weights = attn_weights.to(value.dtype)
     exp_values = torch.bmm(exp_weights, value)
     max_score = max_score.squeeze(-1)
     return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
@@ -166,7 +166,7 @@ def _get_attention_scores_no_kv_chunking(
     attn_scores /= summed
     attn_probs = attn_scores

-    hidden_states_slice = torch.bmm(attn_probs, value)
+    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
     return hidden_states_slice

 class ScannedChunk(NamedTuple):
@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np
+from . import utils

 class BaseModel(torch.nn.Module):
     def __init__(self, model_config, v_prediction=False):
@@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module):

         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
+        self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
         self.v_prediction = v_prediction
@@ -50,7 +52,13 @@ class BaseModel(torch.nn.Module):
         else:
             xc = x
         context = torch.cat(c_crossattn, 1)
-        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
+        dtype = self.get_dtype()
+        xc = xc.to(dtype)
+        t = t.to(dtype)
+        context = context.to(dtype)
+        if c_adm is not None:
+            c_adm = c_adm.to(dtype)
+        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()

     def get_dtype(self):
         return self.diffusion_model.dtype
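The apply_model change above casts every UNet input to the model's weight dtype and upcasts the prediction back to fp32, which is what lets a later hunk drop the autocast context from KSampler. The pattern in isolation (a sketch, using bfloat16 so it also runs on CPU; the module below is a stand-in, not the real UNet):

```python
import torch

net = torch.nn.Linear(8, 8).to(torch.bfloat16)  # stand-in for the low-precision diffusion model
x = torch.randn(2, 8)                           # fp32 input from the caller

dtype = next(net.parameters()).dtype            # what get_dtype() reports
y = net(x.to(dtype)).float()                    # low-precision compute, fp32 result for the samplers
```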
@@ -83,6 +91,16 @@ class BaseModel(torch.nn.Module):
     def process_latent_out(self, latent):
         return self.latent_format.process_out(latent)

+    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
+        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
+        unet_state_dict = self.diffusion_model.state_dict()
+        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
+        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
+        if self.get_dtype() == torch.float16:
+            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
+            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
+        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
+

 class SD21UNCLIP(BaseModel):
     def __init__(self, model_config, noise_aug_config, v_prediction=True):
@@ -144,10 +162,10 @@ class SDXLRefiner(BaseModel):

         print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
         out = []
-        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([height])))
-        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([crop_h])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([aesthetic_score])))
         flat = torch.flatten(torch.cat(out))[None, ]
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
@@ -168,11 +186,11 @@ class SDXL(BaseModel):

         print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
         out = []
-        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([height])))
-        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([crop_h])))
-        out.append(self.embedder(torch.Tensor([target_width])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([target_height])))
+        out.append(self.embedder(torch.Tensor([target_width])))
         flat = torch.flatten(torch.cat(out))[None, ]
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
@@ -16,13 +16,11 @@ def count_blocks(state_dict_keys, prefix_string):

 def detect_unet_config(state_dict, key_prefix, use_fp16):
     state_dict_keys = list(state_dict.keys())
-    num_res_blocks = 2

     unet_config = {
         "use_checkpoint": False,
         "image_size": 32,
         "out_channels": 4,
-        "num_res_blocks": num_res_blocks,
         "use_spatial_transformer": True,
         "legacy": False
     }
@@ -110,11 +108,13 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
     unet_config["context_dim"] = context_dim
     return unet_config

-def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
-    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+def model_config_from_unet_config(unet_config):
     for model_config in supported_models.models:
         if model_config.matches(unet_config):
             return model_config(unet_config)

     return None

+def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
+    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+    return model_config_from_unet_config(unet_config)
@@ -139,7 +139,23 @@ else:
     except:
         XFORMERS_IS_AVAILABLE = False

+def is_nvidia():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.cuda:
+            return True
+
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+
+if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+    try:
+        if is_nvidia():
+            torch_version = torch.version.__version__
+            if int(torch_version[0]) >= 2:
+                ENABLE_PYTORCH_ATTENTION = True
+    except:
+        pass
+
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
@@ -155,10 +171,15 @@ elif args.highvram or args.gpu_only:
     vram_state = VRAMState.HIGH_VRAM

 FORCE_FP32 = False
+FORCE_FP16 = False
 if args.force_fp32:
     print("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True

+if args.force_fp16:
+    print("Forcing FP16.")
+    FORCE_FP16 = True
+
 if lowvram_available:
     try:
         import accelerate
@@ -212,10 +233,9 @@ def unload_model():
             accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
             model_accelerated = False

-        #never unload models from GPU on high vram
-        if vram_state != VRAMState.HIGH_VRAM:
-            current_loaded_model.model.cpu()
-        current_loaded_model.model_patches_to("cpu")
+        current_loaded_model.model.to(current_loaded_model.offload_device)
+        current_loaded_model.model_patches_to(current_loaded_model.offload_device)
         current_loaded_model.unpatch_model()
         current_loaded_model = None

@@ -225,6 +245,8 @@ def unload_model():
             n.cpu()
         current_gpu_controlnets = []

+def minimum_inference_memory():
+    return (768 * 1024 * 1024)

 def load_model_gpu(model):
     global current_loaded_model
@@ -240,15 +262,20 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e

-    torch_dev = get_torch_device()
+    torch_dev = model.load_device
     model.model_patches_to(torch_dev)
+    model.model_patches_to(model.model_dtype())
+
+    if is_device_cpu(torch_dev):
+        vram_set_state = VRAMState.DISABLED
+    else:
+        vram_set_state = vram_state

-    vram_set_state = vram_state
     if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
         model_size = model.model_size()
         current_free_mem = get_free_memory(torch_dev)
         lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
-        if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
+        if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
             vram_set_state = VRAMState.LOW_VRAM

     current_loaded_model = model
@@ -257,14 +284,14 @@ def load_model_gpu(model):
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
-        real_model.to(get_torch_device())
+        real_model.to(torch_dev)
     else:
         if vram_set_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_set_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})

-        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
         model_accelerated = True
     return current_loaded_model

@@ -307,12 +334,46 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model

-def text_encoder_device():
+def unet_offload_device():
+    if vram_state == VRAMState.HIGH_VRAM:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
+def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
     else:
         return torch.device("cpu")

+def text_encoder_device():
+    if args.gpu_only:
+        return get_torch_device()
+    elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
+        if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
+            return get_torch_device()
+        else:
+            return torch.device("cpu")
+    else:
+        return torch.device("cpu")
+
+def vae_device():
+    return get_torch_device()
+
+def vae_offload_device():
+    if args.gpu_only:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
+def vae_dtype():
+    if args.fp16_vae:
+        return torch.float16
+    elif args.bf16_vae:
+        return torch.bfloat16
+    else:
+        return torch.float32
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
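The block above splits device policy into small helpers: where the UNet, text encoder and VAE run, where they are offloaded to, and which dtype the VAE uses. A hedged sketch of how a caller might combine them (the helper names come from the hunk; the call site itself and the module import path are assumptions):

```python
import torch
import comfy.model_management as mm  # assumed module path for the hunk above

load_dev = mm.get_torch_device()        # device used for inference
offload_dev = mm.unet_offload_device()  # cpu unless running in high-VRAM mode
vae_dt = mm.vae_dtype()                 # float32 unless --fp16-vae / --bf16-vae was passed

module = torch.nn.Identity()            # stand-in for a real model component
module.to(load_dev, dtype=vae_dt)       # bring it in for work
module.to(offload_dev)                  # push it back out afterwards
```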
@@ -347,7 +408,7 @@ def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
         #TODO: more reliable way of checking for flash attention?
-        if torch.version.cuda: #pytorch flash attention only works on Nvidia
+        if is_nvidia(): #pytorch flash attention only works on Nvidia
             return True
     return False

@@ -402,10 +463,29 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS

-def should_use_fp16():
+def is_device_cpu(device):
+    if hasattr(device, 'type'):
+        if (device.type == 'cpu'):
+            return True
+    return False
+
+def is_device_mps(device):
+    if hasattr(device, 'type'):
+        if (device.type == 'mps'):
+            return True
+    return False
+
+def should_use_fp16(device=None, model_params=0):
     global xpu_available
     global directml_enabled

+    if FORCE_FP16:
+        return True
+
+    if device is not None: #TODO
+        if is_device_cpu(device) or is_device_mps(device):
+            return False
+
     if FORCE_FP32:
         return False

@@ -419,10 +499,27 @@ def should_use_fp16():
         return True

     props = torch.cuda.get_device_properties("cuda")
+    if props.major < 6:
+        return False
+
+    fp16_works = False
+    #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
+    #when the model doesn't actually fit on the card
+    #TODO: actually test if GP106 and others have the same type of behavior
+    nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"]
+    for x in nvidia_10_series:
+        if x in props.name.lower():
+            fp16_works = True
+
+    if fp16_works:
+        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        if model_params * 4 > free_model_memory:
+            return True
+
     if props.major < 7:
         return False

-    #FP32 is faster on those cards?
+    #FP16 is just broken on these cards
     nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"]
     for x in nvidia_16_series:
         if x in props.name:
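The new Pascal branch only forces fp16 when the fp32 copy of the model would not fit: roughly 4 bytes per parameter compared against 90% of free memory minus the 768 MiB inference headroom added earlier in this commit. Worked arithmetic with illustrative numbers (the parameter count and free VRAM are stand-ins, not measured values):

```python
model_params = 860_000_000                  # ballpark SD1.x-class UNet, illustrative only
fp32_bytes = model_params * 4               # about 3.2 GiB of weights in fp32

free_vram = 8 * 1024 ** 3                   # pretend 8 GiB are free
minimum_inference_memory = 768 * 1024 ** 2  # matches the helper added in this commit
free_model_memory = free_vram * 0.9 - minimum_inference_memory

force_fp16_on_pascal = fp32_bytes > free_model_memory
print(force_fp16_on_pascal)                 # False: the fp32 model still fits on this card
```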
@@ -438,7 +535,7 @@ def soft_empty_cache():
     elif xpu_available:
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
-        if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
+        if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()

@@ -51,11 +51,11 @@ def get_models_from_cond(cond, model_type):
             models += [c[1][model_type]]
     return models

-def load_additional_models(positive, negative):
+def load_additional_models(positive, negative, dtype):
     """loads additional models in positive and negative conditioning"""
     control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
     gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
-    gligen = [x[1] for x in gligen]
+    gligen = [x[1].to(dtype) for x in gligen]
     models = control_nets + gligen
     comfy.model_management.load_controlnet_gpu(models)
     return models
@@ -65,7 +65,7 @@ def cleanup_additional_models(models):
     for m in models:
         m.cleanup()

-def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False):
+def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
     device = comfy.model_management.get_torch_device()

     if noise_mask is not None:
@@ -81,11 +81,11 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     positive_copy = broadcast_cond(positive, noise.shape[0], device)
     negative_copy = broadcast_cond(negative, noise.shape[0], device)

-    models = load_additional_models(positive, negative)
+    models = load_additional_models(positive, negative, model.model_dtype())

     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
+    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
     samples = samples.cpu()

     cleanup_additional_models(models)
@@ -2,7 +2,6 @@ from .k_diffusion import sampling as k_diffusion_sampling
 from .k_diffusion import external as k_diffusion_external
 from .extra_samplers import uni_pc
 import torch
-import contextlib
 from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
@@ -13,7 +12,7 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)

 #The main sampling function shared by all the samplers
 #Returns predicted noise
-def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None):
     def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
@@ -292,8 +291,8 @@ class CFGNoisePredictor(torch.nn.Module):
         super().__init__()
         self.inner_model = model
         self.alphas_cumprod = model.alphas_cumprod
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
-        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
+    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None):
+        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed)
         return out


@@ -301,11 +300,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None):
         if denoise_mask is not None:
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
-        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
+        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed)
         if denoise_mask is not None:
             out *= denoise_mask

@@ -375,7 +374,7 @@ def resolve_cond_masks(conditions, h, w, device):
             modified = c[1].copy()
             if len(mask.shape) == 2:
                 mask = mask.unsqueeze(0)
-            if mask.shape[2] != h or mask.shape[3] != w:
+            if mask.shape[1] != h or mask.shape[2] != w:
                 mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)

             if modified.get("set_area_to_bounds", False):
@@ -483,8 +482,8 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
 class KSampler:
     SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
     SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
-                "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
-                "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
+                "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
+                "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]

     def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
         self.model = model
@ -542,7 +541,7 @@ class KSampler:
|
|||||||
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
||||||
self.sigmas = sigmas[-(steps + 1):]
|
self.sigmas = sigmas[-(steps + 1):]
|
||||||
|
|
||||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False):
|
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||||
if sigmas is None:
|
if sigmas is None:
|
||||||
sigmas = self.sigmas
|
sigmas = self.sigmas
|
||||||
sigma_min = self.sigma_min
|
sigma_min = self.sigma_min
|
||||||
@ -577,11 +576,6 @@ class KSampler:
|
|||||||
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
|
||||||
if self.model.get_dtype() == torch.float16:
|
|
||||||
precision_scope = torch.autocast
|
|
||||||
else:
|
|
||||||
precision_scope = contextlib.nullcontext
|
|
||||||
|
|
||||||
if self.model.is_adm():
|
if self.model.is_adm():
|
||||||
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
|
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
|
||||||
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
|
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
|
||||||
@ -589,7 +583,7 @@ class KSampler:
|
|||||||
if latent_image is not None:
|
if latent_image is not None:
|
||||||
latent_image = self.model.process_latent_in(latent_image)
|
latent_image = self.model.process_latent_in(latent_image)
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed}
|
||||||
|
|
||||||
cond_concat = None
|
cond_concat = None
|
||||||
if hasattr(self.model, 'concat_keys'): #inpaint
|
if hasattr(self.model, 'concat_keys'): #inpaint
|
||||||
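The hunks above thread a new seed argument from KSampler.sample down through the model wrappers and into extra_args. As a purely illustrative, self-contained sketch of that pattern (toy_denoise and toy_ancestral_sampler are made-up names, not functions from this code base), a seed carried in extra_args can drive a dedicated torch.Generator so the noise a sampler injects is reproducible:

import torch

def toy_denoise(x, sigma):
    # stand-in for a real model call
    return x / (1.0 + sigma)

def toy_ancestral_sampler(x, sigmas, extra_args):
    # extra_args carries the seed, the same way KSampler.sample packs "seed"
    # into extra_args above; a per-run generator makes the injected noise
    # reproducible across runs with the same seed.
    gen = torch.Generator(device=x.device)
    seed = extra_args.get("seed")
    if seed is not None:
        gen.manual_seed(seed)
    for i in range(len(sigmas) - 1):
        denoised = toy_denoise(x, sigmas[i])
        noise = torch.randn(x.shape, generator=gen, device=x.device)
        x = denoised + noise * sigmas[i + 1]
    return x

latent = torch.zeros(1, 4, 8, 8)
out = toy_ancestral_sampler(latent, [14.6, 7.0, 1.0, 0.0], {"seed": 1234})
print(out.shape)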
@@ -612,67 +606,67 @@ class KSampler:
         else:
             max_denoise = True

-        with precision_scope(model_management.get_autocast_device(self.device)):
-            if self.sampler == "uni_pc":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
-            elif self.sampler == "uni_pc_bh2":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
-            elif self.sampler == "ddim":
-                timesteps = []
-                for s in range(sigmas.shape[0]):
-                    timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
-                noise_mask = None
-                if denoise_mask is not None:
-                    noise_mask = 1.0 - denoise_mask
-
-                ddim_callback = None
-                if callback is not None:
-                    total_steps = len(timesteps) - 1
-                    ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
-
-                sampler = DDIMSampler(self.model, device=self.device)
-                sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
-                z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
-                samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
-                                                   conditioning=positive,
-                                                   batch_size=noise.shape[0],
-                                                   shape=noise.shape[1:],
-                                                   verbose=False,
-                                                   unconditional_guidance_scale=cfg,
-                                                   unconditional_conditioning=negative,
-                                                   eta=0.0,
-                                                   x_T=z_enc,
-                                                   x0=latent_image,
-                                                   img_callback=ddim_callback,
-                                                   denoise_function=sampling_function,
-                                                   extra_args=extra_args,
-                                                   mask=noise_mask,
-                                                   to_zero=sigmas[-1]==0,
-                                                   end_step=sigmas.shape[0] - 1,
-                                                   disable_pbar=disable_pbar)
-
-            else:
-                extra_args["denoise_mask"] = denoise_mask
-                self.model_k.latent_image = latent_image
-                self.model_k.noise = noise
-
-                if max_denoise:
-                    noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
-                else:
-                    noise = noise * sigmas[0]
-
-                k_callback = None
-                total_steps = len(sigmas) - 1
-                if callback is not None:
-                    k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
-
-                if latent_image is not None:
-                    noise += latent_image
-                if self.sampler == "dpm_fast":
-                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
-                elif self.sampler == "dpm_adaptive":
-                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
-                else:
-                    samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+        if self.sampler == "uni_pc":
+            samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
+        elif self.sampler == "uni_pc_bh2":
+            samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
+        elif self.sampler == "ddim":
+            timesteps = []
+            for s in range(sigmas.shape[0]):
+                timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
+            noise_mask = None
+            if denoise_mask is not None:
+                noise_mask = 1.0 - denoise_mask
+
+            ddim_callback = None
+            if callback is not None:
+                total_steps = len(timesteps) - 1
+                ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
+
+            sampler = DDIMSampler(self.model, device=self.device)
+            sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
+            z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
+            samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
+                                               conditioning=positive,
+                                               batch_size=noise.shape[0],
+                                               shape=noise.shape[1:],
+                                               verbose=False,
+                                               unconditional_guidance_scale=cfg,
+                                               unconditional_conditioning=negative,
+                                               eta=0.0,
+                                               x_T=z_enc,
+                                               x0=latent_image,
+                                               img_callback=ddim_callback,
+                                               denoise_function=sampling_function,
+                                               extra_args=extra_args,
+                                               mask=noise_mask,
+                                               to_zero=sigmas[-1]==0,
+                                               end_step=sigmas.shape[0] - 1,
+                                               disable_pbar=disable_pbar)
+
+        else:
+            extra_args["denoise_mask"] = denoise_mask
+            self.model_k.latent_image = latent_image
+            self.model_k.noise = noise
+
+            if max_denoise:
+                noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
+            else:
+                noise = noise * sigmas[0]
+
+            k_callback = None
+            total_steps = len(sigmas) - 1
+            if callback is not None:
+                k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
+
+            if latent_image is not None:
+                noise += latent_image
+            if self.sampler == "dpm_fast":
+                samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+            elif self.sampler == "dpm_adaptive":
+                samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+            else:
+                samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)

         return self.model.process_latent_out(samples.to(torch.float32))
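The block above (apparently comfy/samplers.py) keeps the name-based dispatch into the k-diffusion sampler functions, getattr(k_diffusion_sampling, "sample_" + name). A minimal standalone sketch of that dispatch idea, with trivial stand-in sampler functions rather than the real k-diffusion implementations:

import torch

class FakeSamplingModule:
    # stand-in for a module exposing sample_<name> functions
    @staticmethod
    def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=False):
        for i in range(len(sigmas) - 1):
            x = model(x, sigmas[i], **(extra_args or {}))
        return x

    @staticmethod
    def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=False):
        # same trivial loop; a real implementation would differ
        return FakeSamplingModule.sample_euler(model, x, sigmas, extra_args, callback, disable)

def dispatch(sampler_name, model, x, sigmas, extra_args):
    # look the function up by name, exactly like the diff above does
    fn = getattr(FakeSamplingModule, "sample_{}".format(sampler_name))
    return fn(model, x, sigmas, extra_args=extra_args)

model = lambda x, sigma, **kw: x / (1.0 + sigma)
x = torch.zeros(1, 4, 8, 8)
print(dispatch("euler", model, x, [14.6, 1.0, 0.0], {}).shape)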
329 comfy/sd.py

@@ -59,38 +59,8 @@ LORA_CLIP_MAP = {
     "self_attn.out_proj": "self_attn_out_proj",
 }

-LORA_UNET_MAP_ATTENTIONS = {
-    "proj_in": "proj_in",
-    "proj_out": "proj_out",
-}
-
-transformer_lora_blocks = {
-    "transformer_blocks.{}.attn1.to_q": "transformer_blocks_{}_attn1_to_q",
-    "transformer_blocks.{}.attn1.to_k": "transformer_blocks_{}_attn1_to_k",
-    "transformer_blocks.{}.attn1.to_v": "transformer_blocks_{}_attn1_to_v",
-    "transformer_blocks.{}.attn1.to_out.0": "transformer_blocks_{}_attn1_to_out_0",
-    "transformer_blocks.{}.attn2.to_q": "transformer_blocks_{}_attn2_to_q",
-    "transformer_blocks.{}.attn2.to_k": "transformer_blocks_{}_attn2_to_k",
-    "transformer_blocks.{}.attn2.to_v": "transformer_blocks_{}_attn2_to_v",
-    "transformer_blocks.{}.attn2.to_out.0": "transformer_blocks_{}_attn2_to_out_0",
-    "transformer_blocks.{}.ff.net.0.proj": "transformer_blocks_{}_ff_net_0_proj",
-    "transformer_blocks.{}.ff.net.2": "transformer_blocks_{}_ff_net_2",
-}
-
-for i in range(10):
-    for k in transformer_lora_blocks:
-        LORA_UNET_MAP_ATTENTIONS[k.format(i)] = transformer_lora_blocks[k].format(i)
-
-
-LORA_UNET_MAP_RESNET = {
-    "in_layers.2": "resnets_{}_conv1",
-    "emb_layers.1": "resnets_{}_time_emb_proj",
-    "out_layers.3": "resnets_{}_conv2",
-    "skip_connection": "resnets_{}_conv_shortcut"
-}
-
-def load_lora(path, to_load):
-    lora = utils.load_torch_file(path, safe_load=True)
+def load_lora(lora, to_load):
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
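The first comfy/sd.py hunk changes load_lora to accept an already-loaded state dict instead of a file path, so file I/O happens once and the parsed dict can be reused. A hedged sketch of what that calling convention looks like from the outside — load_lora_file and count_matching are hypothetical helpers, and the file name is made up:

import torch
import safetensors.torch

def load_lora_file(path):
    # mirrors the split above: read the file once, pass the dict around afterwards
    if path.lower().endswith(".safetensors"):
        return safetensors.torch.load_file(path, device="cpu")
    return torch.load(path, map_location="cpu")

def count_matching(lora_sd, key_map):
    # purely illustrative: how many LoRA tensors start with a mapped prefix
    return sum(1 for k in lora_sd if any(k.startswith(p) for p in key_map))

# Usage (file name is hypothetical):
# lora_sd = load_lora_file("style_test.safetensors")
# print(count_matching(lora_sd, {"lora_unet_": None, "lora_te_": None}))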
@@ -189,113 +159,59 @@ def load_lora(path, to_load):
             print("lora key not loaded", x)
     return patch_dict

-def model_lora_keys(model, key_map={}):
+def model_lora_keys_clip(model, key_map={}):
     sdk = model.state_dict().keys()

-    counter = 0
-    for b in range(12):
-        tk = "diffusion_model.input_blocks.{}.1".format(b)
-        up_counter = 0
-        for c in LORA_UNET_MAP_ATTENTIONS:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
-                key_map[lora_key] = k
-                up_counter += 1
-        if up_counter >= 4:
-            counter += 1
-    for c in LORA_UNET_MAP_ATTENTIONS:
-        k = "diffusion_model.middle_block.1.{}.weight".format(c)
-        if k in sdk:
-            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
-            key_map[lora_key] = k
-    counter = 3
-    for b in range(12):
-        tk = "diffusion_model.output_blocks.{}.1".format(b)
-        up_counter = 0
-        for c in LORA_UNET_MAP_ATTENTIONS:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
-                key_map[lora_key] = k
-                up_counter += 1
-        if up_counter >= 4:
-            counter += 1
-    counter = 0
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
-    for b in range(24):
+    clip_l_present = False
+    for b in range(32):
         for c in LORA_CLIP_MAP:
             k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
                 lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                 key_map[lora_key] = k

+            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
+            if k in sdk:
+                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
+                key_map[lora_key] = k
+                clip_l_present = True
+
+            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
+            if k in sdk:
+                if clip_l_present:
+                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
+                else:
+                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
+                key_map[lora_key] = k
+
-    #Locon stuff
-    ds_counter = 0
-    counter = 0
-    for b in range(12):
-        tk = "diffusion_model.input_blocks.{}.0".format(b)
-        key_in = False
-        for c in LORA_UNET_MAP_RESNET:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
-                key_map[lora_key] = k
-                key_in = True
-        for bb in range(3):
-            k = "{}.{}.op.weight".format(tk[:-2], bb)
-            if k in sdk:
-                lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
-                key_map[lora_key] = k
-                ds_counter += 1
-        if key_in:
-            counter += 1
-
-    counter = 0
-    for b in range(3):
-        tk = "diffusion_model.middle_block.{}".format(b)
-        key_in = False
-        for c in LORA_UNET_MAP_RESNET:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
-                key_map[lora_key] = k
-                key_in = True
-        if key_in:
-            counter += 1
-
-    counter = 0
-    us_counter = 0
-    for b in range(12):
-        tk = "diffusion_model.output_blocks.{}.0".format(b)
-        key_in = False
-        for c in LORA_UNET_MAP_RESNET:
-            k = "{}.{}.weight".format(tk, c)
-            if k in sdk:
-                lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
-                key_map[lora_key] = k
-                key_in = True
-        for bb in range(3):
-            k = "{}.{}.conv.weight".format(tk[:-2], bb)
-            if k in sdk:
-                lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
-                key_map[lora_key] = k
-                us_counter += 1
-        if key_in:
-            counter += 1
-
     return key_map

+def model_lora_keys_unet(model, key_map={}):
+    sdk = model.state_dict().keys()
+
+    for k in sdk:
+        if k.startswith("diffusion_model.") and k.endswith(".weight"):
+            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
+            key_map["lora_unet_{}".format(key_lora)] = k
+
+    diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
+    for k in diffusers_keys:
+        if k.endswith(".weight"):
+            key_lora = k[:-len(".weight")].replace(".", "_")
+            key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
+    return key_map
+
 class ModelPatcher:
-    def __init__(self, model, size=0):
+    def __init__(self, model, load_device, offload_device, size=0):
         self.size = size
         self.model = model
         self.patches = []
         self.backup = {}
         self.model_options = {"transformer_options":{}}
         self.model_size()
+        self.load_device = load_device
+        self.offload_device = offload_device

     def model_size(self):
         if self.size > 0:
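The removed hand-written LORA_UNET_MAP tables are replaced by key maps derived directly from module names (model_lora_keys_unet above). A self-contained sketch of that derivation on a toy module; TinyUNet and unet_lora_keys are illustrative names only:

import torch

class TinyUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.input_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])
        self.middle_block = torch.nn.Linear(4, 4)

def unet_lora_keys(model, prefix="lora_unet_"):
    # generate LoRA key names from the state-dict keys themselves by
    # swapping "." for "_", instead of maintaining a lookup table
    key_map = {}
    for k in model.state_dict().keys():
        if k.endswith(".weight"):
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map[prefix + key_lora] = k
    return key_map

print(unet_lora_keys(TinyUNet()))
# e.g. {'lora_unet_input_blocks_0': 'input_blocks.0.weight',
#       'lora_unet_middle_block': 'middle_block.weight'}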
@@ -310,7 +226,7 @@ class ModelPatcher:
         return size

     def clone(self):
-        n = ModelPatcher(self.model, self.size)
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
         n.patches = self.patches[:]
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys

@@ -322,6 +238,9 @@ class ModelPatcher:
         else:
             self.model_options["sampler_cfg_function"] = sampler_cfg_function

+    def set_model_unet_function_wrapper(self, unet_wrapper_function):
+        self.model_options["model_function_wrapper"] = unet_wrapper_function
+
     def set_model_patch(self, patch, name):
         to = self.model_options["transformer_options"]
         if "patches" not in to:

@@ -372,7 +291,8 @@ class ModelPatcher:
             patch_list[k] = patch_list[k].to(device)

     def model_dtype(self):
-        return self.model.get_dtype()
+        if hasattr(self.model, "get_dtype"):
+            return self.model.get_dtype()

     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         p = {}

@@ -481,10 +401,10 @@ class ModelPatcher:

         self.backup = {}

-def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
-    key_map = model_lora_keys(model.model)
-    key_map = model_lora_keys(clip.cond_stage_model, key_map)
-    loaded = load_lora(lora_path, key_map)
+def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
+    key_map = model_lora_keys_unet(model.model)
+    key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
+    loaded = load_lora(lora, key_map)
     new_modelpatcher = model.clone()
     k = new_modelpatcher.add_patches(loaded, strength_model)
     new_clip = clip.clone()
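The new set_model_unet_function_wrapper hook stores a callable under model_options["model_function_wrapper"]. The sketch below is an assumption about how such a wrapper can be shaped — it receives the original apply function plus a dict of its arguments — and is not taken from this commit; base_apply, my_wrapper and call_model are made-up names:

import torch

def base_apply(x, timestep, **kwargs):
    return x * 0.5  # stand-in for the real UNet forward

def my_wrapper(apply_fn, params):
    # params is assumed to be a dict of the call's arguments
    x = params["input"]
    t = params["timestep"]
    out = apply_fn(x, t)
    return out + 0.1  # e.g. post-process the prediction

model_options = {"model_function_wrapper": my_wrapper}

def call_model(x, t, model_options):
    # if a wrapper is registered, let it drive the call; otherwise call directly
    wrapper = model_options.get("model_function_wrapper")
    if wrapper is not None:
        return wrapper(base_apply, {"input": x, "timestep": t})
    return base_apply(x, t)

print(call_model(torch.ones(1, 4, 8, 8), torch.tensor([10]), model_options).mean())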
@@ -502,17 +422,22 @@ class CLIP:
     def __init__(self, target=None, embedding_directory=None, no_init=False):
         if no_init:
             return
-        params = target.params
+        params = target.params.copy()
         clip = target.clip
         tokenizer = target.tokenizer

-        self.device = model_management.text_encoder_device()
-        params["device"] = self.device
+        load_device = model_management.text_encoder_device()
+        offload_device = model_management.text_encoder_offload_device()
+        params['device'] = load_device
         self.cond_stage_model = clip(**(params))
-        self.cond_stage_model = self.cond_stage_model.to(self.device)
+        #TODO: make sure this doesn't have a quality loss before enabling.
+        # if model_management.should_use_fp16(load_device):
+        #     self.cond_stage_model.half()
+
+        self.cond_stage_model = self.cond_stage_model.to()

         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
-        self.patcher = ModelPatcher(self.cond_stage_model)
+        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None

     def clone(self):

@@ -521,7 +446,6 @@ class CLIP:
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
         n.layer_idx = self.layer_idx
-        n.device = self.device
         return n

     def load_from_state_dict(self, sd):

@@ -539,18 +463,12 @@ class CLIP:
     def encode_from_tokens(self, tokens, return_pooled=False):
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
-        try:
-            self.patcher.patch_model()
-            cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
-            self.patcher.unpatch_model()
-        except Exception as e:
-            self.patcher.unpatch_model()
-            raise e
-
-        cond_out = cond
+        model_management.load_model_gpu(self.patcher)
+        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
         if return_pooled:
-            return cond_out, pooled
-        return cond_out
+            return cond, pooled
+        return cond

     def encode(self, text):
         tokens = self.tokenize(text)

@@ -559,6 +477,15 @@ class CLIP:
     def load_sd(self, sd):
         return self.cond_stage_model.load_sd(sd)

+    def get_sd(self):
+        return self.cond_stage_model.state_dict()
+
+    def patch_model(self):
+        self.patcher.patch_model()
+
+    def unpatch_model(self):
+        self.patcher.unpatch_model()
+
 class VAE:
     def __init__(self, ckpt_path=None, device=None, config=None):
         if config is None:
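The CLIP class above gains thin get_sd / patch_model / unpatch_model passthroughs so the checkpoint-saving path can reach the patched text-encoder weights. A toy illustration of that delegation pattern (ToyPatcher and ToyCLIP are stand-ins, not classes from this code base):

import torch

class ToyPatcher:
    def __init__(self, model):
        self.model = model
        self.patched = False
    def patch_model(self):
        self.patched = True   # a real patcher would apply weight patches here
    def unpatch_model(self):
        self.patched = False

class ToyCLIP:
    def __init__(self):
        self.cond_stage_model = torch.nn.Linear(4, 4)
        self.patcher = ToyPatcher(self.cond_stage_model)
    def get_sd(self):
        # expose the underlying module's weights for saving
        return self.cond_stage_model.state_dict()
    def patch_model(self):
        self.patcher.patch_model()
    def unpatch_model(self):
        self.patcher.unpatch_model()

clip = ToyCLIP()
clip.patch_model()
print(sorted(clip.get_sd().keys()), clip.patcher.patched)
clip.unpatch_model()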
@@ -575,8 +502,11 @@ class VAE:
             self.first_stage_model.load_state_dict(sd, strict=False)

         if device is None:
-            device = model_management.get_torch_device()
+            device = model_management.vae_device()
         self.device = device
+        self.offload_device = model_management.vae_offload_device()
+        self.vae_dtype = model_management.vae_dtype()
+        self.first_stage_model.to(self.vae_dtype)

     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)

@@ -584,7 +514,7 @@ class VAE:
         steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)

-        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0)
+        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
             (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +

@@ -598,7 +528,7 @@ class VAE:
         steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)

-        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample()
+        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.vae_dtype).to(self.device) - 1.).sample().float()
         samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)

@@ -615,13 +545,13 @@ class VAE:

             pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
             for x in range(0, samples_in.shape[0], batch_number):
-                samples = samples_in[x:x+batch_number].to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
+                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu().float()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)

-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples

@@ -629,7 +559,7 @@ class VAE:
         model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
         output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         return output.movedim(1,-1)

     def encode(self, pixel_samples):

@@ -642,14 +572,14 @@ class VAE:
             batch_number = max(1, batch_number)
             samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
             for x in range(0, pixel_samples.shape[0], batch_number):
-                pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu()
+                pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float()

         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
             samples = self.encode_tiled_(pixel_samples)

-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
         return samples

     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):

@@ -657,9 +587,13 @@ class VAE:
         self.first_stage_model = self.first_stage_model.to(self.device)
         pixel_samples = pixel_samples.movedim(-1,1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
-        self.first_stage_model = self.first_stage_model.cpu()
+        self.first_stage_model = self.first_stage_model.to(self.offload_device)
        return samples

+    def get_sd(self):
+        return self.first_stage_model.state_dict()
+

 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
     #print(current_batch_size, target_batch_size)
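The VAE hunks above cast inputs to a working dtype before decode/encode, return float32 results, and park the module on an offload device when done. A generic, self-contained sketch of that pattern; the dtype and device choices here are assumptions, and ToyVAE is a stand-in:

import torch

class ToyVAE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Conv2d(4, 3, 1)

    def decode(self, z):
        return self.net(z)

work_dtype = torch.float32          # e.g. torch.float16 on a GPU build
device = torch.device("cpu")        # e.g. "cuda"
offload_device = torch.device("cpu")

vae = ToyVAE().to(device, dtype=work_dtype)
latent = torch.randn(1, 4, 8, 8)

with torch.no_grad():
    # cast the input to the working dtype, run on the compute device,
    # hand back a float32 result on the CPU
    image = vae.decode(latent.to(work_dtype).to(device)).float().cpu()

vae.to(offload_device)              # free the compute device once done
print(image.dtype, image.shape)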
@@ -1061,6 +995,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     if fp16:
         model = model.half()

+    offload_device = model_management.unet_offload_device()
+    model = model.to(offload_device)
     model.load_model_weights(state_dict, "model.diffusion_model.")

     if output_vae:

@@ -1083,8 +1019,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         w.cond_stage_model = clip.cond_stage_model
         load_clip_weights(w, state_dict)

-    return (ModelPatcher(model), clip, vae)
+    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
+
+def calculate_parameters(sd, prefix):
+    params = 0
+    for k in sd.keys():
+        if k.startswith(prefix):
+            params += sd[k].nelement()
+    return params

 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
     sd = utils.load_torch_file(ckpt_path)

@@ -1095,7 +1037,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     model = None
     clip_target = None

-    fp16 = model_management.should_use_fp16()
+    parameters = calculate_parameters(sd, "model.diffusion_model.")
+    fp16 = model_management.should_use_fp16(model_params=parameters)

     class WeightsLoader(torch.nn.Module):
         pass

@@ -1108,7 +1051,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_clipvision:
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

-    model = model_config.get_model(sd)
+    offload_device = model_management.unet_offload_device()
+    model = model_config.get_model(sd, "model.diffusion_model.")
+    model = model.to(offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
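calculate_parameters above feeds a parameter count into should_use_fp16. A standalone sketch of the counting step and of how a size-based precision heuristic could key off it (the threshold below is illustrative, not the project's rule):

import torch

def count_params(sd, prefix):
    # sum the element counts of every tensor whose key starts with the prefix
    return sum(t.nelement() for k, t in sd.items() if k.startswith(prefix))

sd = {
    "model.diffusion_model.w1": torch.zeros(320, 4, 3, 3),
    "model.diffusion_model.w2": torch.zeros(320),
    "first_stage_model.w": torch.zeros(512, 512),
}
n = count_params(sd, "model.diffusion_model.")
print(n)                      # 320*4*3*3 + 320
use_fp16 = n > 10_000_000     # illustrative threshold only
print(use_fp16)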
@@ -1129,4 +1074,84 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if len(left_over) > 0:
         print("left over keys:", left_over)

-    return (ModelPatcher(model), clip, vae, clipvision)
+    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
+
+
+def load_unet(unet_path): #load unet in diffusers format
+    sd = utils.load_torch_file(unet_path)
+    parameters = calculate_parameters(sd, "")
+    fp16 = model_management.should_use_fp16(model_params=parameters)
+
+    match = {}
+    match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    match["model_channels"] = sd["conv_in.weight"].shape[0]
+    match["in_channels"] = sd["conv_in.weight"].shape[1]
+    match["adm_in_channels"] = None
+    if "class_embedding.linear_1.weight" in sd:
+        match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
+
+    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
+            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
+            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+
+    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
+                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+
+    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+            'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+    print("match", match)
+    for unet_config in supported_models:
+        matches = True
+        for k in match:
+            if match[k] != unet_config[k]:
+                matches = False
+                break
+        if matches:
+            diffusers_keys = utils.unet_to_diffusers(unet_config)
+            new_sd = {}
+            for k in diffusers_keys:
+                if k in sd:
+                    new_sd[diffusers_keys[k]] = sd.pop(k)
+                else:
+                    print(diffusers_keys[k], k)
+            offload_device = model_management.unet_offload_device()
+            model_config = model_detection.model_config_from_unet_config(unet_config)
+            model = model_config.get_model(new_sd, "")
+            model = model.to(offload_device)
+            model.load_model_weights(new_sd, "")
+            return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+
+def save_checkpoint(output_path, model, clip, vae, metadata=None):
+    try:
+        model.patch_model()
+        clip.patch_model()
+        sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
+        utils.save_torch_file(sd, output_path, metadata=metadata)
+        model.unpatch_model()
+        clip.unpatch_model()
+    except Exception as e:
+        model.unpatch_model()
+        clip.unpatch_model()
+        raise e
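load_unet above identifies a diffusers-format UNet by probing a few tensor shapes in the state dict and comparing them against candidate configs. A toy version of that shape-probing idea with made-up candidate configs (the names and numbers below are illustrative, not real model families):

import torch

def probe(sd):
    # read a couple of identifying shapes out of the state dict
    match = {}
    match["model_channels"] = sd["conv_in.weight"].shape[0]
    match["in_channels"] = sd["conv_in.weight"].shape[1]
    return match

CANDIDATES = {
    "toy_small": {"model_channels": 32, "in_channels": 4},
    "toy_big": {"model_channels": 64, "in_channels": 4},
}

def identify(sd):
    m = probe(sd)
    for name, cfg in CANDIDATES.items():
        if all(m[k] == cfg[k] for k in m):
            return name
    return None

fake_sd = {"conv_in.weight": torch.zeros(64, 4, 3, 3)}
print(identify(fake_sd))   # toy_big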
@@ -5,24 +5,34 @@ import comfy.ops
 import torch
 import traceback
 import zipfile
+from . import model_management
+import contextlib

 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        z_empty, _ = self.encode(self.empty_tokens)
-        output = []
-        first_pooled = None
+        to_encode = list(self.empty_tokens)
         for x in token_weight_pairs:
-            tokens = [list(map(lambda a: a[0], x))]
-            z, pooled = self.encode(tokens)
-            if first_pooled is None:
-                first_pooled = pooled
+            tokens = list(map(lambda a: a[0], x))
+            to_encode.append(tokens)
+
+        out, pooled = self.encode(to_encode)
+        z_empty = out[0:1]
+        if pooled.shape[0] > 1:
+            first_pooled = pooled[1:2]
+        else:
+            first_pooled = pooled[0:1]
+
+        output = []
+        for k in range(1, out.shape[0]):
+            z = out[k:k+1]
             for i in range(len(z)):
                 for j in range(len(z[i])):
-                    weight = x[j][1]
+                    weight = token_weight_pairs[k - 1][j][1]
                     z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
-            output += [z]
+            output.append(z)

         if (len(output) == 0):
-            return self.encode(self.empty_tokens)
+            return z_empty, first_pooled
         return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()

 class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):

@@ -46,7 +56,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         with modeling_utils.no_init_weights():
             self.transformer = CLIPTextModel(config)

-        self.device = device
         self.max_length = max_length
         if freeze:
             self.freeze()

@@ -95,7 +104,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             out_tokens += [tokens_temp]

         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1])
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
             new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
             n = token_dict_size
             for x in embedding_weights:

@@ -106,24 +115,32 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):

     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()
+        device = backup_embeds.weight.device
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(self.device)
-        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
-        self.transformer.set_input_embeddings(backup_embeds)
-
-        if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
-        else:
-            z = outputs.hidden_states[self.layer_idx]
-            if self.layer_norm_hidden_state:
-                z = self.transformer.text_model.final_layer_norm(z)
-
-        pooled_output = outputs.pooler_output
-        if self.text_projection is not None:
-            pooled_output = pooled_output @ self.text_projection
-        return z, pooled_output
+        tokens = torch.LongTensor(tokens).to(device)
+
+        if backup_embeds.weight.dtype != torch.float32:
+            precision_scope = torch.autocast
+        else:
+            precision_scope = contextlib.nullcontext
+
+        with precision_scope(model_management.get_autocast_device(device)):
+            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            self.transformer.set_input_embeddings(backup_embeds)
+
+            if self.layer == "last":
+                z = outputs.last_hidden_state
+            elif self.layer == "pooled":
+                z = outputs.pooler_output[:, None, :]
+            else:
+                z = outputs.hidden_states[self.layer_idx]
+                if self.layer_norm_hidden_state:
+                    z = self.transformer.text_model.final_layer_norm(z)
+
+            pooled_output = outputs.pooler_output
+            if self.text_projection is not None:
+                pooled_output = pooled_output @ self.text_projection
+        return z.float(), pooled_output.float()

     def encode(self, tokens):
         return self(tokens)
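The ClipTokenWeightEncoder rewrite above (apparently comfy/sd1_clip.py) batches the empty prompt together with every prompt chunk and then re-applies the per-token weights as z' = (z - z_empty) * w + z_empty. A self-contained numeric sketch of exactly that re-weighting step:

import torch

def reweight(batch_z, weights):
    # batch_z: [1 + n, tokens, dim]; row 0 is the empty-prompt encoding
    z_empty = batch_z[0:1]
    out = []
    for k in range(1, batch_z.shape[0]):
        z = batch_z[k:k + 1].clone()
        w = weights[k - 1].view(1, -1, 1)
        # pull each token embedding toward/away from the empty embedding
        out.append((z - z_empty) * w + z_empty)
    return torch.cat(out, dim=-2)

batch_z = torch.randn(3, 77, 768)          # empty + two prompt chunks
weights = torch.ones(2, 77)
weights[0, 5] = 1.3                        # emphasize one token
print(reweight(batch_z, weights).shape)    # torch.Size([1, 154, 768])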
@@ -3,9 +3,9 @@ import torch
 import os

 class SD2ClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config)
+        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
         if layer == "last":
             pass

@@ -3,11 +3,12 @@ import torch
 import os

 class SDXLClipG(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config)
+        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
         self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
+        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = False
         if layer == "last":
             pass
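Registering logit_scale as an nn.Parameter (as SDXLClipG now does) is what makes the key visible to state_dict() and therefore to the saving code paths added elsewhere in this commit. A short demonstration with a toy module:

import torch

class WithScale(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.text_projection = torch.nn.Parameter(torch.empty(4, 4))
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))

m = WithScale()
print(sorted(m.state_dict().keys()))  # ['logit_scale', 'text_projection']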
@@ -9,6 +9,8 @@ from . import sdxl_clip
 from . import supported_models_base
 from . import latent_formats

+from . import diffusers_convert
+
 class SD15(supported_models_base.BASE):
     unet_config = {
         "context_dim": 768,

@@ -51,9 +53,9 @@ class SD20(supported_models_base.BASE):

     latent_format = latent_formats.SD15

-    def v_prediction(self, state_dict):
+    def v_prediction(self, state_dict, prefix=""):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
-            k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
+            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
             out = state_dict[k]
             if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                 return True

@@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE):
         state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        replace_prefix[""] = "cond_stage_model.model."
+        state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
+        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
+        return state_dict
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

@@ -100,7 +109,7 @@ class SDXLRefiner(supported_models_base.BASE):

     latent_format = latent_formats.SDXL

-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         return model_base.SDXLRefiner(self)

     def process_clip_state_dict(self, state_dict):

@@ -109,10 +118,18 @@ class SDXLRefiner(supported_models_base.BASE):

         state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
         keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
+        keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"

         state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
+        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
+        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
+        return state_dict_g
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

@@ -127,7 +144,7 @@ class SDXL(supported_models_base.BASE):

     latent_format = latent_formats.SDXL

-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         return model_base.SDXL(self)

     def process_clip_state_dict(self, state_dict):

@@ -137,11 +154,25 @@ class SDXL(supported_models_base.BASE):
         replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
         state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
         keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
+        keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"

         state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
         state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        keys_to_replace = {}
+        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
+        for k in state_dict:
+            if k.startswith("clip_l"):
+                state_dict_g[k] = state_dict[k]
+
+        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
+        replace_prefix["clip_l"] = "conditioner.embedders.0"
+        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
+        return state_dict_g
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
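The new process_*_state_dict_for_saving methods lean on prefix rewriting of state-dict keys. The helper below is an illustrative re-implementation of that idea, not the project's own state_dict_prefix_replace:

import torch

def prefix_replace(state_dict, replace_prefix):
    # rewrite keys that start with one of the given prefixes
    out = {}
    for k, v in state_dict.items():
        new_k = k
        for old, new in replace_prefix.items():
            if k.startswith(old):
                new_k = new + k[len(old):]
                break
        out[new_k] = v
    return out

sd = {"clip_g.transformer.w": torch.zeros(2), "clip_l.transformer.w": torch.zeros(2)}
print(sorted(prefix_replace(sd, {"clip_g": "conditioner.embedders.1.model",
                                 "clip_l": "conditioner.embedders.0"}).keys()))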

comfy/supported_models_base.py

@@ -41,7 +41,7 @@ class BASE:
                 return False
         return True

-    def v_prediction(self, state_dict):
+    def v_prediction(self, state_dict, prefix=""):
         return False

     def inpaint_model(self):
@@ -53,14 +53,26 @@ class BASE:
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]

-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         if self.inpaint_model():
-            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
+            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
         else:
-            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))
+            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))

     def process_clip_state_dict(self, state_dict):
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "cond_stage_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def process_unet_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "model.diffusion_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def process_vae_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "first_stage_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
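The three *_for_saving helpers above all rely on the same prefix rewrite: every key gets its checkpoint-side prefix put back in front before the merged state dict is written out. A minimal stand-alone sketch of that idea; state_dict_prefix_replace here is a simplified stand-in for the helper of the same name in this file, not the repo implementation:

def state_dict_prefix_replace(state_dict, replace_prefix):
    # Rewrite each key whose start matches an entry of replace_prefix.
    out = {}
    for k, v in state_dict.items():
        new_k = k
        for old, new in replace_prefix.items():
            if k.startswith(old):
                new_k = new + k[len(old):]
                break
        out[new_k] = v
    return out

unet_sd = {"input_blocks.0.0.weight": 0}
print(state_dict_prefix_replace(unet_sd, {"": "model.diffusion_model."}))
# {'model.diffusion_model.input_blocks.0.0.weight': 0}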

comfy/utils.py (154 changed lines)

@@ -2,10 +2,10 @@ import torch
 import math
 import struct
 import comfy.checkpoint_pickle
+import safetensors.torch

 def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
-        import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
     else:
         if safe_load:
@@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False):
             sd = pl_sd
     return sd

+def save_torch_file(sd, ckpt, metadata=None):
+    if metadata is not None:
+        safetensors.torch.save_file(sd, ckpt, metadata=metadata)
+    else:
+        safetensors.torch.save_file(sd, ckpt)
+
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",
@@ -64,6 +70,152 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
             sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
     return sd

+UNET_MAP_ATTENTIONS = {
+    "proj_in.weight",
+    "proj_in.bias",
+    "proj_out.weight",
+    "proj_out.bias",
+    "norm.weight",
+    "norm.bias",
+}
+
+TRANSFORMER_BLOCKS = {
+    "norm1.weight",
+    "norm1.bias",
+    "norm2.weight",
+    "norm2.bias",
+    "norm3.weight",
+    "norm3.bias",
+    "attn1.to_q.weight",
+    "attn1.to_k.weight",
+    "attn1.to_v.weight",
+    "attn1.to_out.0.weight",
+    "attn1.to_out.0.bias",
+    "attn2.to_q.weight",
+    "attn2.to_k.weight",
+    "attn2.to_v.weight",
+    "attn2.to_out.0.weight",
+    "attn2.to_out.0.bias",
+    "ff.net.0.proj.weight",
+    "ff.net.0.proj.bias",
+    "ff.net.2.weight",
+    "ff.net.2.bias",
+}
+
+UNET_MAP_RESNET = {
+    "in_layers.2.weight": "conv1.weight",
+    "in_layers.2.bias": "conv1.bias",
+    "emb_layers.1.weight": "time_emb_proj.weight",
+    "emb_layers.1.bias": "time_emb_proj.bias",
+    "out_layers.3.weight": "conv2.weight",
+    "out_layers.3.bias": "conv2.bias",
+    "skip_connection.weight": "conv_shortcut.weight",
+    "skip_connection.bias": "conv_shortcut.bias",
+    "in_layers.0.weight": "norm1.weight",
+    "in_layers.0.bias": "norm1.bias",
+    "out_layers.0.weight": "norm2.weight",
+    "out_layers.0.bias": "norm2.bias",
+}
+
+UNET_MAP_BASIC = {
+    "label_emb.0.0.weight": "class_embedding.linear_1.weight",
+    "label_emb.0.0.bias": "class_embedding.linear_1.bias",
+    "label_emb.0.2.weight": "class_embedding.linear_2.weight",
+    "label_emb.0.2.bias": "class_embedding.linear_2.bias",
+    "input_blocks.0.0.weight": "conv_in.weight",
+    "input_blocks.0.0.bias": "conv_in.bias",
+    "out.0.weight": "conv_norm_out.weight",
+    "out.0.bias": "conv_norm_out.bias",
+    "out.2.weight": "conv_out.weight",
+    "out.2.bias": "conv_out.bias",
+    "time_embed.0.weight": "time_embedding.linear_1.weight",
+    "time_embed.0.bias": "time_embedding.linear_1.bias",
+    "time_embed.2.weight": "time_embedding.linear_2.weight",
+    "time_embed.2.bias": "time_embedding.linear_2.bias"
+}
+
+def unet_to_diffusers(unet_config):
+    num_res_blocks = unet_config["num_res_blocks"]
+    attention_resolutions = unet_config["attention_resolutions"]
+    channel_mult = unet_config["channel_mult"]
+    transformer_depth = unet_config["transformer_depth"]
+    num_blocks = len(channel_mult)
+    if isinstance(num_res_blocks, int):
+        num_res_blocks = [num_res_blocks] * num_blocks
+    if isinstance(transformer_depth, int):
+        transformer_depth = [transformer_depth] * num_blocks
+
+    transformers_per_layer = []
+    res = 1
+    for i in range(num_blocks):
+        transformers = 0
+        if res in attention_resolutions:
+            transformers = transformer_depth[i]
+        transformers_per_layer.append(transformers)
+        res *= 2
+
+    transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
+
+    diffusers_unet_map = {}
+    for x in range(num_blocks):
+        n = 1 + (num_res_blocks[x] + 1) * x
+        for i in range(num_res_blocks[x]):
+            for b in UNET_MAP_RESNET:
+                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
+            if transformers_per_layer[x] > 0:
+                for b in UNET_MAP_ATTENTIONS:
+                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
+                for t in range(transformers_per_layer[x]):
+                    for b in TRANSFORMER_BLOCKS:
+                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
+            n += 1
+        for k in ["weight", "bias"]:
+            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
+
+    i = 0
+    for b in UNET_MAP_ATTENTIONS:
+        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
+    for t in range(transformers_mid):
+        for b in TRANSFORMER_BLOCKS:
+            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
+
+    for i, n in enumerate([0, 2]):
+        for b in UNET_MAP_RESNET:
+            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
+
+    num_res_blocks = list(reversed(num_res_blocks))
+    transformers_per_layer = list(reversed(transformers_per_layer))
+    for x in range(num_blocks):
+        n = (num_res_blocks[x] + 1) * x
+        l = num_res_blocks[x] + 1
+        for i in range(l):
+            c = 0
+            for b in UNET_MAP_RESNET:
+                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
+            c += 1
+            if transformers_per_layer[x] > 0:
+                c += 1
+                for b in UNET_MAP_ATTENTIONS:
+                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
+                for t in range(transformers_per_layer[x]):
+                    for b in TRANSFORMER_BLOCKS:
+                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
+            if i == l - 1:
+                for k in ["weight", "bias"]:
+                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
+            n += 1
+
+    for k in UNET_MAP_BASIC:
+        diffusers_unet_map[UNET_MAP_BASIC[k]] = k
+
+    return diffusers_unet_map
+
+def convert_sd_to(state_dict, dtype):
+    keys = list(state_dict.keys())
+    for k in keys:
+        state_dict[k] = state_dict[k].to(dtype)
+    return state_dict
+
 def safetensors_header(safetensors_path, max_size=100*1024*1024):
     with open(safetensors_path, "rb") as f:
         header = f.read(8)
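Together, the new helpers in comfy/utils.py are enough to rename a diffusers-layout UNet state dict into ComfyUI's internal key layout and write the result back to disk. A minimal sketch, assuming a hand-written unet_config (the values below are illustrative, not read from a real checkpoint):

import comfy.utils

unet_config = {
    "num_res_blocks": 2,
    "attention_resolutions": [4, 2, 1],
    "channel_mult": [1, 2, 4, 4],
    "transformer_depth": 1,
}
# unet_to_diffusers() returns a dict mapping diffusers-style keys to the
# internal "input_blocks/middle_block/output_blocks" names.
mapping = comfy.utils.unet_to_diffusers(unet_config)

def rename_diffusers_unet(diffusers_sd):
    # Keys not covered by the mapping are simply dropped in this sketch.
    return {ldm_key: diffusers_sd[df_key] for df_key, ldm_key in mapping.items() if df_key in diffusers_sd}

# save_torch_file() then writes the renamed dict as .safetensors, optionally
# with string metadata embedded in the header:
# comfy.utils.save_torch_file(rename_diffusers_unet(sd), "unet.safetensors", metadata={"note": "converted"})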

comfy_extras/nodes_clip_sdxl.py (56 changed lines, new file)

@@ -0,0 +1,56 @@
+import torch
+from nodes import MAX_RESOLUTION
+
+class CLIPTextEncodeSDXLRefiner:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
+            "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
+            "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
+            "text": ("STRING", {"multiline": True}), "clip": ("CLIP", ),
+            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning"
+
+    def encode(self, clip, ascore, width, height, text):
+        tokens = clip.tokenize(text)
+        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
+        return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], )
+
+class CLIPTextEncodeSDXL:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
+            "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
+            "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
+            "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
+            "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
+            "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
+            "text_g": ("STRING", {"multiline": True, "default": "CLIP_G"}), "clip": ("CLIP", ),
+            "text_l": ("STRING", {"multiline": True, "default": "CLIP_L"}), "clip": ("CLIP", ),
+            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning"
+
+    def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l):
+        tokens = clip.tokenize(text_g)
+        tokens["l"] = clip.tokenize(text_l)["l"]
+        if len(tokens["l"]) != len(tokens["g"]):
+            empty = clip.tokenize("")
+            while len(tokens["l"]) < len(tokens["g"]):
+                tokens["l"] += empty["l"]
+            while len(tokens["l"]) > len(tokens["g"]):
+                tokens["g"] += empty["g"]
+        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
+        return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], )
+
+NODE_CLASS_MAPPINGS = {
+    "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
+    "CLIPTextEncodeSDXL": CLIPTextEncodeSDXL,
+}
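CLIPTextEncodeSDXL.encode() above pads whichever of the two token streams is shorter, so the clip-g and clip-l encoders always see the same number of chunks. A stand-alone sketch of just that padding rule, with dummy token chunks standing in for real clip.tokenize() output:

def pad_to_same_length(tokens_g, tokens_l, empty_g, empty_l):
    # Pad the shorter list with "empty prompt" chunks until both match.
    tokens_g, tokens_l = list(tokens_g), list(tokens_l)
    while len(tokens_l) < len(tokens_g):
        tokens_l.append(empty_l)
    while len(tokens_l) > len(tokens_g):
        tokens_g.append(empty_g)
    return tokens_g, tokens_l

g, l = pad_to_same_length([["g1"], ["g2"]], [["l1"]], ["g_empty"], ["l_empty"])
assert len(g) == len(l) == 2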

comfy_extras/nodes_model_merging.py

@@ -1,4 +1,8 @@
 import comfy.sd
+import comfy.utils
+import folder_paths
+import json
+import os

 class ModelMergeSimple:
     @classmethod
@@ -10,7 +14,7 @@ class ModelMergeSimple:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "merge"

-    CATEGORY = "_for_testing/model_merging"
+    CATEGORY = "advanced/model_merging"

     def merge(self, model1, model2, ratio):
         m = model1.clone()
@@ -31,7 +35,7 @@ class ModelMergeBlocks:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "merge"

-    CATEGORY = "_for_testing/model_merging"
+    CATEGORY = "advanced/model_merging"

     def merge(self, model1, model2, **kwargs):
         m = model1.clone()
@@ -42,14 +46,52 @@ class ModelMergeBlocks:
                 ratio = default_ratio
             k_unet = k[len("diffusion_model."):]

+            last_arg_size = 0
             for arg in kwargs:
-                if k_unet.startswith(arg):
+                if k_unet.startswith(arg) and last_arg_size < len(arg):
                     ratio = kwargs[arg]
+                    last_arg_size = len(arg)

             m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
         return (m, )

+class CheckpointSave:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip": ("CLIP",),
+                              "vae": ("VAE",),
+                              "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
+    RETURN_TYPES = ()
+    FUNCTION = "save"
+    OUTPUT_NODE = True
+
+    CATEGORY = "advanced/model_merging"
+
+    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+        prompt_info = ""
+        if prompt is not None:
+            prompt_info = json.dumps(prompt)
+
+        metadata = {"prompt": prompt_info}
+        if extra_pnginfo is not None:
+            for x in extra_pnginfo:
+                metadata[x] = json.dumps(extra_pnginfo[x])
+
+        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
+        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
+
+        comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
+        return {}
+
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSimple": ModelMergeSimple,
-    "ModelMergeBlocks": ModelMergeBlocks
+    "ModelMergeBlocks": ModelMergeBlocks,
+    "CheckpointSave": CheckpointSave,
 }
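The last_arg_size change in ModelMergeBlocks.merge() makes the longest matching block-name prefix win, rather than whichever kwargs entry happens to be iterated last. A small sketch of that selection rule with made-up keys and ratios:

def pick_ratio(k_unet, default_ratio, **kwargs):
    # Choose the ratio whose argument name is the longest prefix of k_unet.
    ratio = default_ratio
    last_arg_size = 0
    for arg in kwargs:
        if k_unet.startswith(arg) and last_arg_size < len(arg):
            ratio = kwargs[arg]
            last_arg_size = len(arg)
    return ratio

print(pick_ratio("input_blocks.4.1.proj_in.weight", 0.5,
                 **{"input_blocks.4": 0.8, "input_blocks": 0.2}))
# -> 0.8 (longest prefix); the previous logic could return 0.2 here,
# depending on argument order.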

execution.py (24 changed lines)

@@ -110,7 +110,7 @@ def format_value(x):
     else:
         return str(x)

-def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
+def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
     unique_id = current_item
     inputs = prompt[unique_id]['inputs']
     class_type = prompt[unique_id]['class_type']
@@ -125,7 +125,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
             input_unique_id = input_data[0]
             output_index = input_data[1]
             if input_unique_id not in outputs:
-                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
+                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
                 if result[0] is not True:
                     # Another node failed further upstream
                     return result
@@ -136,7 +136,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
    if server.client_id is not None:
        server.last_node_id = unique_id
        server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
-    obj = class_def()
+    obj = object_storage.get((unique_id, class_type), None)
+    if obj is None:
+        obj = class_def()
+        object_storage[(unique_id, class_type)] = obj
+
    output_data, output_ui = get_output_data(obj, input_data_all)
    outputs[unique_id] = output_data
@@ -256,6 +260,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
 class PromptExecutor:
     def __init__(self, server):
         self.outputs = {}
+        self.object_storage = {}
         self.outputs_ui = {}
         self.old_prompt = {}
         self.server = server
@@ -322,6 +327,17 @@ class PromptExecutor:
         for o in to_delete:
             d = self.outputs.pop(o)
             del d
+        to_delete = []
+        for o in self.object_storage:
+            if o[0] not in prompt:
+                to_delete += [o]
+            else:
+                p = prompt[o[0]]
+                if o[1] != p['class_type']:
+                    to_delete += [o]
+        for o in to_delete:
+            d = self.object_storage.pop(o)
+            del d

         for x in prompt:
             recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
@@ -349,7 +365,7 @@ class PromptExecutor:
                 # This call shouldn't raise anything if there's an error deep in
                 # the actual SD code, instead it will report the node where the
                 # error was raised
-                success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
+                success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
                 if success is not True:
                     self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                     break
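The object_storage changes reuse one node instance per (unique_id, class_type) across prompt executions, and drop instances whose node id disappeared or changed class. A self-contained sketch of that cache; NodeObjectCache is an illustrative name, not ComfyUI API:

class NodeObjectCache:
    def __init__(self):
        self.storage = {}

    def get_or_create(self, unique_id, class_type, factory):
        # Reuse the cached instance if the same node id/class pair ran before.
        key = (unique_id, class_type)
        obj = self.storage.get(key, None)
        if obj is None:
            obj = factory()
            self.storage[key] = obj
        return obj

    def prune(self, prompt):
        # prompt maps unique_id -> {"class_type": ..., "inputs": ...}
        stale = [k for k in self.storage
                 if k[0] not in prompt or prompt[k[0]]["class_type"] != k[1]]
        for k in stale:
            del self.storage[k]

This is also what lets the LoraLoader instance further down keep its self.loaded_lora between runs instead of re-reading the file from disk.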

extra_model_paths.yaml.example

@@ -8,9 +8,12 @@ a111:
     checkpoints: models/Stable-diffusion
     configs: models/Stable-diffusion
     vae: models/VAE
-    loras: models/Lora
+    loras: |
+         models/Lora
+         models/LyCORIS
     upscale_models: |
                   models/ESRGAN
+                  models/RealESRGAN
                   models/SwinIR
     embeddings: embeddings
     hypernetworks: models/hypernetworks
@@ -21,5 +24,3 @@ a111:
     # checkpoints: models/checkpoints
     # gligen: models/gligen
     # custom_nodes: path/custom_nodes
-
-

folder_paths.py

@@ -14,6 +14,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
 folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
 folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
 folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
+folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
 folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
 folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
 folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)

latent_preview.py

@@ -49,14 +49,8 @@ class TAESDPreviewerImpl(LatentPreviewer):


 class Latent2RGBPreviewer(LatentPreviewer):
-    def __init__(self):
-        self.latent_rgb_factors = torch.tensor([
-            #   R        G        B
-            [0.298, 0.207, 0.208],    # L1
-            [0.187, 0.286, 0.173],    # L2
-            [-0.158, 0.189, 0.264],   # L3
-            [-0.184, -0.271, -0.473], # L4
-        ], device="cpu")
+    def __init__(self, latent_rgb_factors):
+        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")

     def decode_latent_to_preview(self, x0):
         latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
@@ -69,12 +63,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
         return Image.fromarray(latents_ubyte.numpy())


-def get_previewer(device):
+def get_previewer(device, latent_format):
     previewer = None
     method = args.preview_method
     if method != LatentPreviewMethod.NoPreviews:
         # TODO previewer methods
-        taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth")
+        taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)

         if method == LatentPreviewMethod.Auto:
             method = LatentPreviewMethod.Latent2RGB
@@ -86,10 +80,10 @@ def get_previewer(device):
             taesd = TAESD(None, taesd_decoder_path).to(device)
             previewer = TAESDPreviewerImpl(taesd)
         else:
-            print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth")
+            print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

     if previewer is None:
-        previewer = Latent2RGBPreviewer()
+        previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
     return previewer


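Latent2RGBPreviewer now receives its projection factors from the model's latent_format instead of hard-coding the 4-channel SD values; SDXL supplies its own set. A sketch of the projection itself, reusing the factors that were removed above (shapes are illustrative, and the exact scaling in the repo may differ slightly):

import torch

latent_rgb_factors = torch.tensor([
    [0.298, 0.207, 0.208],
    [0.187, 0.286, 0.173],
    [-0.158, 0.189, 0.264],
    [-0.184, -0.271, -0.473],
])

def latent_to_rgb(x0):
    # x0: (C, H, W) latent for a single image; one C x 3 matmul gives RGB.
    rgb = x0.permute(1, 2, 0) @ latent_rgb_factors            # (H, W, 3)
    return (((rgb + 1.0) / 2.0).clamp(0, 1) * 255).to(torch.uint8)

preview = latent_to_rgb(torch.randn(4, 64, 64))
print(preview.shape)  # torch.Size([64, 64, 3])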

main.py (4 changed lines)

@@ -14,10 +14,6 @@ if os.name == "nt":
     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

 if __name__ == "__main__":
-    if args.dont_upcast_attention:
-        print("disabling upcasting of attention")
-        os.environ['ATTN_PRECISION'] = "fp16"
-
     if args.cuda_device is not None:
         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
         print("Set cuda device to:", args.cuda_device)

models/unet/put_unet_files_here (0 changed lines, new file)

nodes.py (100 changed lines)

@@ -86,16 +86,52 @@ class ConditioningAverage :
            print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")

        cond_from = conditioning_from[0][0]
+        pooled_output_from = conditioning_from[0][1].get("pooled_output", None)

        for i in range(len(conditioning_to)):
            t1 = conditioning_to[i][0]
+            pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
            t0 = cond_from[:,:t1.shape[1]]
            if t0.shape[1] < t1.shape[1]:
                t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)

            tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
+            t_to = conditioning_to[i][1].copy()
+            if pooled_output_from is not None and pooled_output_to is not None:
+                t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
+            elif pooled_output_from is not None:
+                t_to["pooled_output"] = pooled_output_from
+
+            n = [tw, t_to]
+            out.append(n)
+        return (out, )
+
+class ConditioningConcat:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "conditioning_to": ("CONDITIONING",),
+            "conditioning_from": ("CONDITIONING",),
+            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "concat"
+
+    CATEGORY = "advanced/conditioning"
+
+    def concat(self, conditioning_to, conditioning_from):
+        out = []
+
+        if len(conditioning_from) > 1:
+            print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
+
+        cond_from = conditioning_from[0][0]
+
+        for i in range(len(conditioning_to)):
+            t1 = conditioning_to[i][0]
+            tw = torch.cat((t1, cond_from),1)
            n = [tw, conditioning_to[i][1].copy()]
            out.append(n)

        return (out, )

class ConditioningSetArea:
@@ -152,6 +188,25 @@ class ConditioningSetMask:
            c.append(n)
        return (c, )

+class ConditioningZeroOut:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", )}}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "zero_out"
+
+    CATEGORY = "advanced/conditioning"
+
+    def zero_out(self, conditioning):
+        c = []
+        for t in conditioning:
+            d = t[1].copy()
+            if "pooled_output" in d:
+                d["pooled_output"] = torch.zeros_like(d["pooled_output"])
+            n = [torch.zeros_like(t[0]), d]
+            c.append(n)
+        return (c, )
+
class VAEDecode:
    @classmethod
    def INPUT_TYPES(s):
@@ -290,8 +345,7 @@ class SaveLatent:
        output["latent_tensor"] = samples["samples"]
        output["latent_format_version_0"] = torch.tensor([])
-
-        safetensors.torch.save_file(output, file, metadata=metadata)
+        comfy.utils.save_torch_file(output, file, metadata=metadata)

        return {}


@@ -375,7 +429,7 @@ class DiffusersLoader:
    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

-    CATEGORY = "advanced/loaders"
+    CATEGORY = "advanced/loaders/deprecated"

    def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
        for search_path in folder_paths.get_folder_paths("diffusers"):
@@ -420,6 +474,9 @@ class CLIPSetLastLayer:
        return (clip,)

class LoraLoader:
+    def __init__(self):
+        self.loaded_lora = None
+
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
@@ -438,7 +495,18 @@ class LoraLoader:
            return (model, clip)

        lora_path = folder_paths.get_full_path("loras", lora_name)
-        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
+        lora = None
+        if self.loaded_lora is not None:
+            if self.loaded_lora[0] == lora_path:
+                lora = self.loaded_lora[1]
+            else:
+                del self.loaded_lora
+
+        if lora is None:
+            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
+            self.loaded_lora = (lora_path, lora)
+
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
        return (model_lora, clip_lora)

class VAELoader:
@@ -516,6 +584,21 @@ class ControlNetApply:
            c.append(n)
        return (c, )

+class UNETLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "load_unet"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_unet(self, unet_name):
+        unet_path = folder_paths.get_full_path("unet", unet_name)
+        model = comfy.sd.load_unet(unet_path)
+        return (model,)
+
class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
@@ -958,7 +1041,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
    if preview_format not in ["JPEG", "PNG"]:
        preview_format = "JPEG"

-    previewer = latent_preview.get_previewer(device)
+    previewer = latent_preview.get_previewer(device, model.model.latent_format)

    pbar = comfy.utils.ProgressBar(steps)
    def callback(step, x0, x, total_steps):
@@ -969,7 +1052,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,

    samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                  denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
-                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback)
+                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed)
    out = latent.copy()
    out["samples"] = samples
    return (out, )
@@ -1335,6 +1418,7 @@ NODE_CLASS_MAPPINGS = {
    "LatentCrop": LatentCrop,
    "LoraLoader": LoraLoader,
    "CLIPLoader": CLIPLoader,
+    "UNETLoader": UNETLoader,
    "DualCLIPLoader": DualCLIPLoader,
    "CLIPVisionEncode": CLIPVisionEncode,
    "StyleModelApply": StyleModelApply,
@@ -1355,6 +1439,9 @@ NODE_CLASS_MAPPINGS = {

    "LoadLatent": LoadLatent,
    "SaveLatent": SaveLatent,
+
+    "ConditioningZeroOut": ConditioningZeroOut,
+    "ConditioningConcat": ConditioningConcat,
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -1516,6 +1603,7 @@ def init_custom_nodes():
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
    load_custom_nodes()
    if args.monitor_nodes:
        print("Monitoring custom nodes for modifications.\n")
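ConditioningZeroOut and ConditioningConcat both operate on the CONDITIONING structure: a list of [tensor, options] pairs where the options dict may carry a pooled_output. A stand-alone sketch of the zero-out behaviour, with made-up tensor shapes:

import torch

# One conditioning entry with SDXL-style shapes (illustrative only).
conditioning = [[torch.randn(1, 77, 2048), {"pooled_output": torch.randn(1, 1280)}]]

def zero_out(conditioning):
    c = []
    for t in conditioning:
        d = t[1].copy()
        if "pooled_output" in d:
            d["pooled_output"] = torch.zeros_like(d["pooled_output"])
        # Zero both the token embeddings and the pooled output.
        c.append([torch.zeros_like(t[0]), d])
    return c

zeroed = zero_out(conditioning)
print(zeroed[0][0].abs().sum().item())  # 0.0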

notebooks/comfyui_colab.ipynb

@@ -144,6 +144,7 @@
     "\n",
     "\n",
     "# ESRGAN upscale model\n",
+    "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
     "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
     "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
     "\n",

web/scripts/app.js

@@ -1468,7 +1468,7 @@ export class ComfyApp {
                    this.loadGraphData(JSON.parse(reader.result));
                };
                reader.readAsText(file);
-            } else if (file.name?.endsWith(".latent")) {
+            } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
                const info = await getLatentMetadata(file);
                if (info.workflow) {
                    this.loadGraphData(JSON.parse(info.workflow));

web/scripts/pnginfo.js

@@ -55,11 +55,12 @@ export function getLatentMetadata(file) {
            const dataView = new DataView(safetensorsData.buffer);
            let header_size = dataView.getUint32(0, true);
            let offset = 8;
-            let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
+            let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size)));
            r(header.__metadata__);
        };

-        reader.readAsArrayBuffer(file);
+        var slice = file.slice(0, 1024 * 1024 * 4);
+        reader.readAsArrayBuffer(slice);
    });
}

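getLatentMetadata() now decodes the safetensors header with TextDecoder and reads only the first 4 MB of the file: the format is an 8-byte little-endian header length followed by that many bytes of JSON whose "__metadata__" entry carries the embedded workflow and prompt strings. A Python sketch of the same read; this helper is illustrative and not part of the repo:

import json
import struct

def read_safetensors_metadata(path):
    # Header length is a little-endian uint64 at the start of the file.
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_size))
    return header.get("__metadata__", {})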

web/scripts/ui.js

@@ -545,7 +545,7 @@ export class ComfyUI {
        const fileInput = $el("input", {
            id: "comfy-file-input",
            type: "file",
-            accept: ".json,image/png,.latent",
+            accept: ".json,image/png,.latent,.safetensors",
            style: {display: "none"},
            parent: document.body,
            onchange: () => {