Merge branch 'comfyanonymous:master' into master
Commit: ac153699fe
@ -45,6 +45,8 @@ jobs:
sed -i '1i../ComfyUI' ./python310._pth
cd ..

git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/

mkdir ComfyUI_windows_portable
mv python_embeded ComfyUI_windows_portable
@ -59,7 +61,7 @@ jobs:

cd ..

"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on ComfyUI_windows_portable.7z ComfyUI_windows_portable
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z

cd ComfyUI_windows_portable
@ -31,12 +31,14 @@ jobs:
echo 'import site' >> ./python311._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
python -m pip wheel torch torchvision torchaudio aiohttp==3.8.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python311._pth
cd ..

git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/

mkdir ComfyUI_windows_portable_nightly_pytorch
mv python_embeded ComfyUI_windows_portable_nightly_pytorch
@ -52,7 +54,7 @@ jobs:

cd ..

"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z

cd ComfyUI_windows_portable_nightly_pytorch
@ -29,7 +29,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@ -69,7 +70,7 @@ There is a portable standalone build for Windows that should work for running on

### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z)

Just download, extract and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints

#### How do I share models between another UI and ComfyUI?

@ -193,7 +194,7 @@ You can set this command line setting to disable the upcasting to fp32 in some c

Use ```--preview-method auto``` to enable previews.

The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
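The manual step above can also be scripted. A minimal illustrative sketch (not part of ComfyUI or this commit; the helper name and `comfy_root` argument are made up, only the two URLs come from the paragraph):

```python
# Illustrative helper: fetch the two TAESD decoders named above into models/vae_approx.
import os
import urllib.request

TAESD_DECODERS = {
    "taesd_decoder.pth": "https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth",      # SD1.x / SD2.x
    "taesdxl_decoder.pth": "https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth",  # SDXL
}

def fetch_taesd_decoders(comfy_root="."):
    target = os.path.join(comfy_root, "models", "vae_approx")
    os.makedirs(target, exist_ok=True)
    for name, url in TAESD_DECODERS.items():
        dest = os.path.join(target, name)
        if not os.path.exists(dest):
            urllib.request.urlretrieve(url, dest)

if __name__ == "__main__":
    fetch_taesd_decoders()
```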
## Support and dev channel
@ -41,7 +41,15 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")

fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")

fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16, might lower quality.")

parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

class LatentPreviewMethod(enum.Enum):
@ -53,7 +61,8 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")

parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
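For reference, a minimal standalone sketch of how the mutually exclusive groups added above behave; this is plain argparse, not ComfyUI code:

```python
# Minimal sketch: --force-fp32 and --force-fp16 now live in one mutually exclusive
# group, so argparse rejects the combination instead of silently accepting both flags.
import argparse

parser = argparse.ArgumentParser()
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true")
fp_group.add_argument("--force-fp16", action="store_true")

print(parser.parse_args(["--force-fp16"]))
# parser.parse_args(["--force-fp16", "--force-fp32"])
# -> error: argument --force-fp32: not allowed with argument --force-fp16
```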
@ -17,7 +17,7 @@
"num_attention_heads": 20,
"num_hidden_layers": 32,
"pad_token_id": 1,
"projection_dim": 512,
"projection_dim": 1280,
"torch_dtype": "float32",
"vocab_size": 49408
}
@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
|
||||
def convert_text_enc_state_dict_v20(text_enc_dict):
|
||||
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
|
||||
new_state_dict = {}
|
||||
capture_qkv_weight = {}
|
||||
capture_qkv_bias = {}
|
||||
for k, v in text_enc_dict.items():
|
||||
if not k.startswith(prefix):
|
||||
continue
|
||||
if (
|
||||
k.endswith(".self_attn.q_proj.weight")
|
||||
or k.endswith(".self_attn.k_proj.weight")
|
||||
|
||||
@ -3,12 +3,13 @@ import os
|
||||
import yaml
|
||||
|
||||
import folder_paths
|
||||
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
|
||||
from comfy.sd import load_checkpoint
|
||||
import os.path as osp
|
||||
import re
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
import diffusers_convert
|
||||
from . import diffusers_convert
|
||||
|
||||
|
||||
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
|
||||
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
|
||||
|
||||
@ -215,10 +215,12 @@ class PositionNet(nn.Module):

def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape
masks = masks.unsqueeze(-1)
dtype = self.linears[0].weight.dtype
masks = masks.unsqueeze(-1).to(dtype)
positive_embeddings = positive_embeddings.to(dtype)

# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C

# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
@ -252,7 +254,8 @@ class Gligen(nn.Module):

if self.lowvram == True:
self.position_net.cpu()
def func_lowvram(key, x):
def func_lowvram(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
module.to(x.device)
r = module(x, objs)
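The hunk above changes the low-VRAM patch callback from `func_lowvram(key, x)` to `func_lowvram(x, extra_options)`, with the block index read from `extra_options["transformer_index"]`. A self-contained sketch of that callback shape; the tiny block class and the driver loop below are invented for illustration:

```python
# Sketch of the new patch signature: patches receive (x, extra_options) and pick the
# module for the current transformer block via "transformer_index".
import torch

class TinyGligenBlock(torch.nn.Module):
    def forward(self, x, objs):
        return x  # the real blocks mix x with the grounding tokens in objs

module_list = [TinyGligenBlock() for _ in range(3)]   # stand-ins for GLIGEN fusers
objs = torch.zeros(1, 4, 16)                          # stand-in grounding tokens

def func_lowvram(x, extra_options):
    key = extra_options["transformer_index"]
    module = module_list[key]
    module.to(x.device)          # move only this block onto the compute device
    return module(x, objs)

# Hypothetical driver: the UNet calls the patch once per transformer block.
x = torch.randn(1, 8, 16)
for i in range(len(module_list)):
    x = func_lowvram(x, {"transformer_index": i})
print(x.shape)
```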
@ -66,6 +66,9 @@ class BatchedBrownianTree:
|
||||
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
|
||||
|
||||
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
||||
self.cpu_tree = True
|
||||
if "cpu" in kwargs:
|
||||
self.cpu_tree = kwargs.pop("cpu")
|
||||
t0, t1, self.sign = self.sort(t0, t1)
|
||||
w0 = kwargs.get('w0', torch.zeros_like(x))
|
||||
if seed is None:
|
||||
@ -77,7 +80,10 @@ class BatchedBrownianTree:
|
||||
except TypeError:
|
||||
seed = [seed]
|
||||
self.batched = False
|
||||
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
||||
if self.cpu_tree:
|
||||
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
|
||||
else:
|
||||
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
||||
|
||||
@staticmethod
|
||||
def sort(a, b):
|
||||
@ -85,7 +91,11 @@ class BatchedBrownianTree:
|
||||
|
||||
def __call__(self, t0, t1):
|
||||
t0, t1, sign = self.sort(t0, t1)
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
||||
if self.cpu_tree:
|
||||
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
|
||||
else:
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
|
||||
|
||||
return w if self.batched else w[0]
|
||||
|
||||
|
||||
@ -104,10 +114,10 @@ class BrownianTreeNoiseSampler:
|
||||
internal timestep.
|
||||
"""
|
||||
|
||||
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
|
||||
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
|
||||
self.transform = transform
|
||||
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
|
||||
self.tree = BatchedBrownianTree(x, t0, t1, seed)
|
||||
self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
|
||||
|
||||
def __call__(self, sigma, sigma_next):
|
||||
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
|
||||
@ -543,7 +553,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
@ -613,8 +624,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if solver_type not in {'heun', 'midpoint'}:
|
||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||
|
||||
seed = extra_args.get("seed", None)
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
@ -649,3 +661,18 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||
|
||||
|
||||
|
||||
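The changes above thread `extra_args["seed"]` into `BrownianTreeNoiseSampler` and introduce `cpu=True`/`cpu=False` variants, so the SDE samplers become reproducible and device-independent by default while the new `*_gpu` samplers keep the noise on the GPU. A simplified, self-contained sketch of the idea; the classes below are stand-ins, not the k-diffusion implementation:

```python
# Simplified stand-in for the seeded noise sampler: one torch.Generator per sampler run,
# kept on the CPU so the same seed gives the same noise regardless of the GPU in use.
import torch

class SeededNoiseSampler:
    def __init__(self, x, seed=None, cpu=True):
        gen_device = "cpu" if cpu else x.device
        self.generator = torch.Generator(device=gen_device)
        if seed is not None:
            self.generator.manual_seed(seed)
        self.shape, self.out_device, self.dtype = x.shape, x.device, x.dtype

    def __call__(self, sigma, sigma_next):
        noise = torch.randn(self.shape, generator=self.generator,
                            device=self.generator.device, dtype=torch.float32)
        return noise.to(device=self.out_device, dtype=self.dtype)

def fake_sde_step(x, extra_args=None, noise_sampler=None):
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)                      # threaded in by KSampler.sample(..., seed=seed)
    noise_sampler = SeededNoiseSampler(x, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    return x + 0.1 * noise_sampler(1.0, 0.5)

x = torch.zeros(1, 4, 8, 8)
print(torch.allclose(fake_sde_step(x, {"seed": 42}), fake_sde_step(x, {"seed": 42})))  # True
```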
@ -9,8 +9,23 @@ class LatentFormat:
class SD15(LatentFormat):
def __init__(self, scale_factor=0.18215):
self.scale_factor = scale_factor
self.latent_rgb_factors = [
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
]
self.taesd_decoder_name = "taesd_decoder.pth"

class SDXL(LatentFormat):
def __init__(self):
self.scale_factor = 0.13025

self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
]
self.taesd_decoder_name = "taesdxl_decoder.pth"
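For context, the `latent_rgb_factors` table is what the fast, low-resolution preview uses: each of the four latent channels contributes a fixed amount of R, G and B. A rough sketch of that projection (illustrative only; the real preview code lives elsewhere in ComfyUI):

```python
# Rough sketch: project a 4-channel latent onto RGB with the 4x3 factor table above.
import torch

latent_rgb_factors = torch.tensor([
    [0.298, 0.207, 0.208],      # contribution of latent channel 1 to (R, G, B)
    [0.187, 0.286, 0.173],
    [-0.158, 0.189, 0.264],
    [-0.184, -0.271, -0.473],
])

def latent_to_rgb_preview(latent):
    # latent: (4, H, W) -> preview: (H, W, 3), clamped to [0, 1]
    rgb = torch.einsum("chw,cr->hwr", latent, latent_rgb_factors)
    return ((rgb + 1.0) / 2.0).clamp(0.0, 1.0)

print(latent_to_rgb_preview(torch.randn(4, 64, 64)).shape)  # torch.Size([64, 64, 3])
```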
@ -180,6 +180,12 @@ class DDIMSampler(object):
)
return samples, intermediates

def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
@ -214,7 +220,7 @@ class DDIMSampler(object):

if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img

if ucg_schedule is not None:
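The new `DDIMSampler.q_sample` above is the standard forward-diffusion draw, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, previously reached through `self.model.q_sample`; keeping a local copy lets the mask/inpaint path run without depending on the wrapped model. A small self-contained version for reference (the schedule values below are only an example, not ComfyUI's):

```python
# Standalone sketch of q_sample: noise a clean latent x0 to timestep t using the
# cumulative alpha schedule. The simple linear beta schedule here is illustrative.
import torch

T = 1000
betas = torch.linspace(0.00085, 0.012, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    a = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    b = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    return a * x_start + b * noise

x0 = torch.randn(1, 4, 8, 8)
print(q_sample(x0, torch.tensor([500])).shape)  # torch.Size([1, 4, 8, 8])
```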
@ -16,11 +16,14 @@ if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
# CrossAttn precision handling
|
||||
import os
|
||||
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
||||
|
||||
from comfy.cli_args import args
|
||||
# CrossAttn precision handling
|
||||
if args.dont_upcast_attention:
|
||||
print("disabling upcasting of attention")
|
||||
_ATTN_PRECISION = "fp16"
|
||||
else:
|
||||
_ATTN_PRECISION = "fp32"
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
@ -275,7 +278,7 @@ class CrossAttentionDoggettx(nn.Module):
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
@ -311,7 +314,7 @@ class CrossAttentionDoggettx(nn.Module):
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
first_op_done = True
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
|
||||
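The `.to(v.dtype)` and `dtype=q.dtype` changes above matter because, when attention is upcast, the scores live in float32 while `v` and the preallocated `r1` are half precision; mixing dtypes in the final einsum would fail. A tiny illustration of the pattern (not ComfyUI code):

```python
# Pattern used above: compute softmax in float32 for stability, then cast the
# probabilities back to the value dtype before multiplying with v.
import torch

scores = torch.randn(2, 16, 16, dtype=torch.bfloat16)   # stand-in for s1 in half precision
probs_fp32 = scores.float().softmax(dim=-1)             # upcast softmax (the "fp32" attention path)
probs = probs_fp32.to(scores.dtype)                      # cast back so it matches v / r1
print(probs.dtype)  # torch.bfloat16
```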
@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels, dtype=dtype),
|
||||
nn.GroupNorm(32, channels, dtype=dtype),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
||||
)
|
||||
@ -244,7 +244,7 @@ class ResBlock(TimestepBlock):
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels, dtype=dtype),
|
||||
nn.GroupNorm(32, self.out_channels, dtype=dtype),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
@ -778,13 +778,13 @@ class UNetModel(nn.Module):
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch, dtype=self.dtype),
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
@ -821,7 +821,7 @@ class UNetModel(nn.Module):
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
|
||||
@ -84,7 +84,7 @@ def _summarize_chunk(
|
||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
torch.exp(attn_weights - max_score, out=attn_weights)
|
||||
exp_weights = attn_weights
|
||||
exp_weights = attn_weights.to(value.dtype)
|
||||
exp_values = torch.bmm(exp_weights, value)
|
||||
max_score = max_score.squeeze(-1)
|
||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||
@ -166,7 +166,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
attn_scores /= summed
|
||||
attn_probs = attn_scores
|
||||
|
||||
hidden_states_slice = torch.bmm(attn_probs, value)
|
||||
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
|
||||
return hidden_states_slice
|
||||
|
||||
class ScannedChunk(NamedTuple):
|
||||
|
||||
@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import numpy as np
|
||||
from . import utils
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, v_prediction=False):
|
||||
@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
self.diffusion_model = UNetModel(**unet_config)
|
||||
self.v_prediction = v_prediction
|
||||
@ -50,7 +52,13 @@ class BaseModel(torch.nn.Module):
|
||||
else:
|
||||
xc = x
|
||||
context = torch.cat(c_crossattn, 1)
|
||||
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
|
||||
dtype = self.get_dtype()
|
||||
xc = xc.to(dtype)
|
||||
t = t.to(dtype)
|
||||
context = context.to(dtype)
|
||||
if c_adm is not None:
|
||||
c_adm = c_adm.to(dtype)
|
||||
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
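The new casts in `apply_model` draw a clean precision boundary: inputs are moved to the UNet's dtype on the way in, and the prediction is returned as float32 so samplers and CFG math stay in full precision. A toy, self-contained illustration of the same pattern (none of these classes are ComfyUI's):

```python
# Toy illustration: cast inputs to the model dtype, run the model, hand fp32 back out.
import torch

class TinyUNet(torch.nn.Module):
    def __init__(self, dtype=torch.float16):
        super().__init__()
        self.dtype = dtype
        self.scale = torch.nn.Parameter(torch.ones(1, dtype=dtype))

    def forward(self, x, t, context=None):
        return x * self.scale          # stand-in for the real denoiser

class TinyBaseModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.diffusion_model = TinyUNet()

    def get_dtype(self):
        return self.diffusion_model.dtype

    def apply_model(self, x, t, context):
        dtype = self.get_dtype()
        out = self.diffusion_model(x.to(dtype), t.to(dtype), context=context.to(dtype))
        return out.float()             # samplers downstream keep working in fp32

model = TinyBaseModel()
eps = model.apply_model(torch.randn(1, 4, 8, 8), torch.tensor([10.0]), torch.randn(1, 77, 768))
print(eps.dtype)  # torch.float32
```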
@ -83,6 +91,16 @@ class BaseModel(torch.nn.Module):
|
||||
def process_latent_out(self, latent):
|
||||
return self.latent_format.process_out(latent)
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
|
||||
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
|
||||
if self.get_dtype() == torch.float16:
|
||||
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
|
||||
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
|
||||
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
||||
|
||||
|
||||
class SD21UNCLIP(BaseModel):
|
||||
def __init__(self, model_config, noise_aug_config, v_prediction=True):
|
||||
@ -144,10 +162,10 @@ class SDXLRefiner(BaseModel):
|
||||
|
||||
print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
|
||||
out = []
|
||||
out.append(self.embedder(torch.Tensor([width])))
|
||||
out.append(self.embedder(torch.Tensor([height])))
|
||||
out.append(self.embedder(torch.Tensor([crop_w])))
|
||||
out.append(self.embedder(torch.Tensor([width])))
|
||||
out.append(self.embedder(torch.Tensor([crop_h])))
|
||||
out.append(self.embedder(torch.Tensor([crop_w])))
|
||||
out.append(self.embedder(torch.Tensor([aesthetic_score])))
|
||||
flat = torch.flatten(torch.cat(out))[None, ]
|
||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||
@ -168,11 +186,11 @@ class SDXL(BaseModel):
|
||||
|
||||
print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
|
||||
out = []
|
||||
out.append(self.embedder(torch.Tensor([width])))
|
||||
out.append(self.embedder(torch.Tensor([height])))
|
||||
out.append(self.embedder(torch.Tensor([crop_w])))
|
||||
out.append(self.embedder(torch.Tensor([width])))
|
||||
out.append(self.embedder(torch.Tensor([crop_h])))
|
||||
out.append(self.embedder(torch.Tensor([target_width])))
|
||||
out.append(self.embedder(torch.Tensor([crop_w])))
|
||||
out.append(self.embedder(torch.Tensor([target_height])))
|
||||
out.append(self.embedder(torch.Tensor([target_width])))
|
||||
flat = torch.flatten(torch.cat(out))[None, ]
|
||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||
|
||||
@ -16,13 +16,11 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, use_fp16):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
num_res_blocks = 2
|
||||
|
||||
unet_config = {
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
"out_channels": 4,
|
||||
"num_res_blocks": num_res_blocks,
|
||||
"use_spatial_transformer": True,
|
||||
"legacy": False
|
||||
}
|
||||
@ -110,11 +108,13 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
|
||||
unet_config["context_dim"] = context_dim
|
||||
return unet_config
|
||||
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
||||
def model_config_from_unet_config(unet_config):
|
||||
for model_config in supported_models.models:
|
||||
if model_config.matches(unet_config):
|
||||
return model_config(unet_config)
|
||||
|
||||
return None
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
||||
return model_config_from_unet_config(unet_config)
|
||||
|
||||
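The refactor above splits detection into two steps: `detect_unet_config` builds a config dict from the state dict, and the new `model_config_from_unet_config` matches that dict against the supported model list. A self-contained analog of the matching step (the classes and the `context_dim` test are invented for illustration):

```python
# Analog of model_config_from_unet_config: walk a list of candidate configs and return
# the first whose matches() accepts the detected unet_config.
class SD15Config:
    @classmethod
    def matches(cls, unet_config):
        return unet_config.get("context_dim") == 768
    def __init__(self, unet_config):
        self.unet_config = unet_config

class SDXLConfig:
    @classmethod
    def matches(cls, unet_config):
        return unet_config.get("context_dim") == 2048
    def __init__(self, unet_config):
        self.unet_config = unet_config

models = [SD15Config, SDXLConfig]

def model_config_from_unet_config(unet_config):
    for model_config in models:
        if model_config.matches(unet_config):
            return model_config(unet_config)
    return None

print(type(model_config_from_unet_config({"context_dim": 2048})).__name__)  # SDXLConfig
```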
@ -139,7 +139,23 @@ else:
|
||||
except:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
|
||||
def is_nvidia():
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.GPU:
|
||||
if torch.version.cuda:
|
||||
return True
|
||||
|
||||
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
|
||||
|
||||
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
try:
|
||||
if is_nvidia():
|
||||
torch_version = torch.version.__version__
|
||||
if int(torch_version[0]) >= 2:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
except:
|
||||
pass
|
||||
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
torch.backends.cuda.enable_flash_sdp(True)
|
||||
@ -155,10 +171,15 @@ elif args.highvram or args.gpu_only:
|
||||
vram_state = VRAMState.HIGH_VRAM
|
||||
|
||||
FORCE_FP32 = False
|
||||
FORCE_FP16 = False
|
||||
if args.force_fp32:
|
||||
print("Forcing FP32, if this improves things please report it.")
|
||||
FORCE_FP32 = True
|
||||
|
||||
if args.force_fp16:
|
||||
print("Forcing FP16.")
|
||||
FORCE_FP16 = True
|
||||
|
||||
if lowvram_available:
|
||||
try:
|
||||
import accelerate
|
||||
@ -212,10 +233,9 @@ def unload_model():
|
||||
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
|
||||
model_accelerated = False
|
||||
|
||||
#never unload models from GPU on high vram
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
current_loaded_model.model.cpu()
|
||||
current_loaded_model.model_patches_to("cpu")
|
||||
|
||||
current_loaded_model.model.to(current_loaded_model.offload_device)
|
||||
current_loaded_model.model_patches_to(current_loaded_model.offload_device)
|
||||
current_loaded_model.unpatch_model()
|
||||
current_loaded_model = None
|
||||
|
||||
@ -225,6 +245,8 @@ def unload_model():
|
||||
n.cpu()
|
||||
current_gpu_controlnets = []
|
||||
|
||||
def minimum_inference_memory():
|
||||
return (768 * 1024 * 1024)
|
||||
|
||||
def load_model_gpu(model):
|
||||
global current_loaded_model
|
||||
@ -240,15 +262,20 @@ def load_model_gpu(model):
|
||||
model.unpatch_model()
|
||||
raise e
|
||||
|
||||
torch_dev = get_torch_device()
|
||||
torch_dev = model.load_device
|
||||
model.model_patches_to(torch_dev)
|
||||
model.model_patches_to(model.model_dtype())
|
||||
|
||||
if is_device_cpu(torch_dev):
|
||||
vram_set_state = VRAMState.DISABLED
|
||||
else:
|
||||
vram_set_state = vram_state
|
||||
|
||||
vram_set_state = vram_state
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = model.model_size()
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
|
||||
if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
|
||||
current_loaded_model = model
|
||||
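Worked numbers for the low-VRAM check above (all figures are assumptions for illustration): the old hard-coded 512 MiB margin becomes `minimum_inference_memory()`, i.e. 768 MiB.

```python
# Worked example with made-up numbers: decide whether to fall back to low-VRAM mode.
def minimum_inference_memory():
    return 768 * 1024 * 1024          # same constant added in this commit

model_size = int(3.4 * 1024**3)       # ~3.4 GiB of model weights (illustrative)
current_free_mem = 4 * 1024**3        # 4 GiB reported free on the device (illustrative)

use_lowvram = model_size > (current_free_mem - minimum_inference_memory())
print(use_lowvram)  # True: the model would not fit next to the 768 MiB inference margin
```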
@ -257,14 +284,14 @@ def load_model_gpu(model):
|
||||
pass
|
||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
||||
model_accelerated = False
|
||||
real_model.to(get_torch_device())
|
||||
real_model.to(torch_dev)
|
||||
else:
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||
elif vram_set_state == VRAMState.LOW_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
|
||||
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
|
||||
model_accelerated = True
|
||||
return current_loaded_model
|
||||
|
||||
@ -307,12 +334,46 @@ def unload_if_low_vram(model):
|
||||
return model.cpu()
|
||||
return model
|
||||
|
||||
def text_encoder_device():
|
||||
def unet_offload_device():
|
||||
if vram_state == VRAMState.HIGH_VRAM:
|
||||
return get_torch_device()
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def text_encoder_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
|
||||
if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
|
||||
return get_torch_device()
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
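A condensed sketch of the placement heuristic introduced above; here the flags are passed explicitly instead of being read from `args` and `vram_state`. The text encoder only moves to the GPU when the CPU has few threads, since on a fast CPU leaving it local avoids shuffling weights back and forth.

```python
# Condensed sketch of text_encoder_device(): prefer the CPU unless --gpu-only is set
# or the CPU is weak enough that encoding there would be the bottleneck.
import torch

def text_encoder_device(gpu_only=False, normal_or_high_vram=True):
    if gpu_only:
        return torch.device("cuda")
    if normal_or_high_vram and torch.get_num_threads() < 8:
        return torch.device("cuda")
    return torch.device("cpu")

print(text_encoder_device())
```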
def vae_device():
|
||||
return get_torch_device()
|
||||
|
||||
def vae_offload_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def vae_dtype():
|
||||
if args.fp16_vae:
|
||||
return torch.float16
|
||||
elif args.bf16_vae:
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
def get_autocast_device(dev):
|
||||
if hasattr(dev, 'type'):
|
||||
return dev.type
|
||||
@ -347,7 +408,7 @@ def pytorch_attention_flash_attention():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
#TODO: more reliable way of checking for flash attention?
|
||||
if torch.version.cuda: #pytorch flash attention only works on Nvidia
|
||||
if is_nvidia(): #pytorch flash attention only works on Nvidia
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -402,10 +463,29 @@ def mps_mode():
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.MPS
|
||||
|
||||
def should_use_fp16():
|
||||
def is_device_cpu(device):
|
||||
if hasattr(device, 'type'):
|
||||
if (device.type == 'cpu'):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_device_mps(device):
|
||||
if hasattr(device, 'type'):
|
||||
if (device.type == 'mps'):
|
||||
return True
|
||||
return False
|
||||
|
||||
def should_use_fp16(device=None, model_params=0):
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
|
||||
if FORCE_FP16:
|
||||
return True
|
||||
|
||||
if device is not None: #TODO
|
||||
if is_device_cpu(device) or is_device_mps(device):
|
||||
return False
|
||||
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
|
||||
@ -419,10 +499,27 @@ def should_use_fp16():
|
||||
return True
|
||||
|
||||
props = torch.cuda.get_device_properties("cuda")
|
||||
if props.major < 6:
|
||||
return False
|
||||
|
||||
fp16_works = False
|
||||
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
||||
#when the model doesn't actually fit on the card
|
||||
#TODO: actually test if GP106 and others have the same type of behavior
|
||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"]
|
||||
for x in nvidia_10_series:
|
||||
if x in props.name.lower():
|
||||
fp16_works = True
|
||||
|
||||
if fp16_works:
|
||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
||||
if model_params * 4 > free_model_memory:
|
||||
return True
|
||||
|
||||
if props.major < 7:
|
||||
return False
|
||||
|
||||
#FP32 is faster on those cards?
|
||||
#FP16 is just broken on these cards
|
||||
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"]
|
||||
for x in nvidia_16_series:
|
||||
if x in props.name:
|
||||
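Worked numbers for the new 10-series rule above (all values are assumptions): on cards where fp16 runs but is slower than fp32, fp16 is only selected when the fp32 weights would not fit in free memory.

```python
# Worked example with assumed numbers for should_use_fp16(device, model_params) on a
# Pascal 10-series card where fp16 works but is slower than fp32.
def minimum_inference_memory():
    return 768 * 1024 * 1024

model_params = 2_600_000_000                 # ~2.6B parameters (illustrative)
free_memory = 8 * 1024**3                    # 8 GiB free on the card (illustrative)

free_model_memory = free_memory * 0.9 - minimum_inference_memory()
use_fp16 = model_params * 4 > free_model_memory   # 4 bytes per fp32 parameter
print(use_fp16)  # True: ~10.4 GB of fp32 weights cannot fit, so fp16 is used instead
```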
@ -438,7 +535,7 @@ def soft_empty_cache():
|
||||
elif xpu_available:
|
||||
torch.xpu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
|
||||
if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
@ -51,11 +51,11 @@ def get_models_from_cond(cond, model_type):
|
||||
models += [c[1][model_type]]
|
||||
return models
|
||||
|
||||
def load_additional_models(positive, negative):
|
||||
def load_additional_models(positive, negative, dtype):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
|
||||
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
||||
gligen = [x[1] for x in gligen]
|
||||
gligen = [x[1].to(dtype) for x in gligen]
|
||||
models = control_nets + gligen
|
||||
comfy.model_management.load_controlnet_gpu(models)
|
||||
return models
|
||||
@ -65,7 +65,7 @@ def cleanup_additional_models(models):
|
||||
for m in models:
|
||||
m.cleanup()
|
||||
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False):
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
if noise_mask is not None:
|
||||
@ -81,11 +81,11 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
positive_copy = broadcast_cond(positive, noise.shape[0], device)
|
||||
negative_copy = broadcast_cond(negative, noise.shape[0], device)
|
||||
|
||||
models = load_additional_models(positive, negative)
|
||||
models = load_additional_models(positive, negative, model.model_dtype())
|
||||
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.cpu()
|
||||
|
||||
cleanup_additional_models(models)
|
||||
|
||||
@ -2,7 +2,6 @@ from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .k_diffusion import external as k_diffusion_external
|
||||
from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import contextlib
|
||||
from comfy import model_management
|
||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||
@ -13,7 +12,7 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
|
||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None):
|
||||
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
@ -292,8 +291,8 @@ class CFGNoisePredictor(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.alphas_cumprod = model.alphas_cumprod
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
|
||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None):
|
||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed)
|
||||
return out
|
||||
|
||||
|
||||
@ -301,11 +300,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None):
|
||||
if denoise_mask is not None:
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
|
||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
|
||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out *= denoise_mask
|
||||
|
||||
@ -375,7 +374,7 @@ def resolve_cond_masks(conditions, h, w, device):
|
||||
modified = c[1].copy()
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[2] != h or mask.shape[3] != w:
|
||||
if mask.shape[1] != h or mask.shape[2] != w:
|
||||
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
|
||||
|
||||
if modified.get("set_area_to_bounds", False):
|
||||
@ -483,8 +482,8 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
||||
class KSampler:
|
||||
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
|
||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||
self.model = model
|
||||
@ -542,7 +541,7 @@ class KSampler:
|
||||
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
||||
self.sigmas = sigmas[-(steps + 1):]
|
||||
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False):
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
if sigmas is None:
|
||||
sigmas = self.sigmas
|
||||
sigma_min = self.sigma_min
|
||||
@ -577,11 +576,6 @@ class KSampler:
|
||||
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
if self.model.get_dtype() == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = contextlib.nullcontext
|
||||
|
||||
if self.model.is_adm():
|
||||
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
|
||||
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
|
||||
@ -589,7 +583,7 @@ class KSampler:
|
||||
if latent_image is not None:
|
||||
latent_image = self.model.process_latent_in(latent_image)
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed}
|
||||
|
||||
cond_concat = None
|
||||
if hasattr(self.model, 'concat_keys'): #inpaint
|
||||
@ -612,67 +606,67 @@ class KSampler:
|
||||
else:
|
||||
max_denoise = True
|
||||
|
||||
with precision_scope(model_management.get_autocast_device(self.device)):
|
||||
if self.sampler == "uni_pc":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
elif self.sampler == "uni_pc_bh2":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
elif self.sampler == "ddim":
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
|
||||
ddim_callback = None
|
||||
if callback is not None:
|
||||
total_steps = len(timesteps) - 1
|
||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||
if self.sampler == "uni_pc":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
elif self.sampler == "uni_pc_bh2":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
elif self.sampler == "ddim":
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
|
||||
sampler = DDIMSampler(self.model, device=self.device)
|
||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
||||
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
|
||||
conditioning=positive,
|
||||
batch_size=noise.shape[0],
|
||||
shape=noise.shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=cfg,
|
||||
unconditional_conditioning=negative,
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
img_callback=ddim_callback,
|
||||
denoise_function=sampling_function,
|
||||
extra_args=extra_args,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1,
|
||||
disable_pbar=disable_pbar)
|
||||
ddim_callback = None
|
||||
if callback is not None:
|
||||
total_steps = len(timesteps) - 1
|
||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||
|
||||
sampler = DDIMSampler(self.model, device=self.device)
|
||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
||||
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
|
||||
conditioning=positive,
|
||||
batch_size=noise.shape[0],
|
||||
shape=noise.shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=cfg,
|
||||
unconditional_conditioning=negative,
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
img_callback=ddim_callback,
|
||||
denoise_function=sampling_function,
|
||||
extra_args=extra_args,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1,
|
||||
disable_pbar=disable_pbar)
|
||||
|
||||
else:
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
self.model_k.latent_image = latent_image
|
||||
self.model_k.noise = noise
|
||||
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
self.model_k.latent_image = latent_image
|
||||
self.model_k.noise = noise
|
||||
noise = noise * sigmas[0]
|
||||
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
noise = noise * sigmas[0]
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if self.sampler == "dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
elif self.sampler == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if self.sampler == "dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
elif self.sampler == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
|
||||
return self.model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
comfy/sd.py (329 changed lines)
@ -59,38 +59,8 @@ LORA_CLIP_MAP = {
|
||||
"self_attn.out_proj": "self_attn_out_proj",
|
||||
}
|
||||
|
||||
LORA_UNET_MAP_ATTENTIONS = {
|
||||
"proj_in": "proj_in",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
|
||||
transformer_lora_blocks = {
|
||||
"transformer_blocks.{}.attn1.to_q": "transformer_blocks_{}_attn1_to_q",
|
||||
"transformer_blocks.{}.attn1.to_k": "transformer_blocks_{}_attn1_to_k",
|
||||
"transformer_blocks.{}.attn1.to_v": "transformer_blocks_{}_attn1_to_v",
|
||||
"transformer_blocks.{}.attn1.to_out.0": "transformer_blocks_{}_attn1_to_out_0",
|
||||
"transformer_blocks.{}.attn2.to_q": "transformer_blocks_{}_attn2_to_q",
|
||||
"transformer_blocks.{}.attn2.to_k": "transformer_blocks_{}_attn2_to_k",
|
||||
"transformer_blocks.{}.attn2.to_v": "transformer_blocks_{}_attn2_to_v",
|
||||
"transformer_blocks.{}.attn2.to_out.0": "transformer_blocks_{}_attn2_to_out_0",
|
||||
"transformer_blocks.{}.ff.net.0.proj": "transformer_blocks_{}_ff_net_0_proj",
|
||||
"transformer_blocks.{}.ff.net.2": "transformer_blocks_{}_ff_net_2",
|
||||
}
|
||||
|
||||
for i in range(10):
|
||||
for k in transformer_lora_blocks:
|
||||
LORA_UNET_MAP_ATTENTIONS[k.format(i)] = transformer_lora_blocks[k].format(i)
|
||||
|
||||
|
||||
LORA_UNET_MAP_RESNET = {
|
||||
"in_layers.2": "resnets_{}_conv1",
|
||||
"emb_layers.1": "resnets_{}_time_emb_proj",
|
||||
"out_layers.3": "resnets_{}_conv2",
|
||||
"skip_connection": "resnets_{}_conv_shortcut"
|
||||
}
|
||||
|
||||
def load_lora(path, to_load):
|
||||
lora = utils.load_torch_file(path, safe_load=True)
|
||||
def load_lora(lora, to_load):
|
||||
patch_dict = {}
|
||||
loaded_keys = set()
|
||||
for x in to_load:
|
||||
@ -189,113 +159,59 @@ def load_lora(path, to_load):
|
||||
print("lora key not loaded", x)
|
||||
return patch_dict
|
||||
|
||||
def model_lora_keys(model, key_map={}):
|
||||
def model_lora_keys_clip(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
|
||||
counter = 0
|
||||
for b in range(12):
|
||||
tk = "diffusion_model.input_blocks.{}.1".format(b)
|
||||
up_counter = 0
|
||||
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||
k = "{}.{}.weight".format(tk, c)
|
||||
if k in sdk:
|
||||
lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
|
||||
key_map[lora_key] = k
|
||||
up_counter += 1
|
||||
if up_counter >= 4:
|
||||
counter += 1
|
||||
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||
k = "diffusion_model.middle_block.1.{}.weight".format(c)
|
||||
if k in sdk:
|
||||
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
|
||||
key_map[lora_key] = k
|
||||
counter = 3
|
||||
for b in range(12):
|
||||
tk = "diffusion_model.output_blocks.{}.1".format(b)
|
||||
up_counter = 0
|
||||
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||
k = "{}.{}.weight".format(tk, c)
|
||||
if k in sdk:
|
||||
lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
|
||||
key_map[lora_key] = k
|
||||
up_counter += 1
|
||||
if up_counter >= 4:
|
||||
counter += 1
|
||||
counter = 0
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
for b in range(24):
|
||||
clip_l_present = False
|
||||
for b in range(32):
|
||||
for c in LORA_CLIP_MAP:
|
||||
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
|
||||
key_map[lora_key] = k
|
||||
|
||||
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
key_map[lora_key] = k
|
||||
clip_l_present = True
|
||||
|
||||
#Locon stuff
|
||||
ds_counter = 0
|
||||
counter = 0
|
||||
for b in range(12):
|
||||
tk = "diffusion_model.input_blocks.{}.0".format(b)
|
||||
key_in = False
|
||||
for c in LORA_UNET_MAP_RESNET:
|
||||
k = "{}.{}.weight".format(tk, c)
|
||||
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
|
||||
if k in sdk:
|
||||
lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
|
||||
if clip_l_present:
|
||||
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
|
||||
else:
|
||||
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
|
||||
key_map[lora_key] = k
|
||||
key_in = True
|
||||
for bb in range(3):
|
||||
k = "{}.{}.op.weight".format(tk[:-2], bb)
|
||||
if k in sdk:
|
||||
lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
|
||||
key_map[lora_key] = k
|
||||
ds_counter += 1
|
||||
if key_in:
|
||||
counter += 1
|
||||
|
||||
counter = 0
|
||||
for b in range(3):
|
||||
tk = "diffusion_model.middle_block.{}".format(b)
|
||||
key_in = False
|
||||
for c in LORA_UNET_MAP_RESNET:
|
||||
k = "{}.{}.weight".format(tk, c)
|
||||
if k in sdk:
|
||||
lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
|
||||
key_map[lora_key] = k
|
||||
key_in = True
|
||||
if key_in:
|
||||
counter += 1
|
||||
|
||||
counter = 0
|
||||
us_counter = 0
|
||||
for b in range(12):
|
||||
tk = "diffusion_model.output_blocks.{}.0".format(b)
|
||||
key_in = False
|
||||
for c in LORA_UNET_MAP_RESNET:
|
||||
k = "{}.{}.weight".format(tk, c)
|
||||
if k in sdk:
|
||||
lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
|
||||
key_map[lora_key] = k
|
||||
key_in = True
|
||||
for bb in range(3):
|
||||
k = "{}.{}.conv.weight".format(tk[:-2], bb)
|
||||
if k in sdk:
|
||||
lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
|
||||
key_map[lora_key] = k
|
||||
us_counter += 1
|
||||
if key_in:
|
||||
counter += 1
|
||||
|
||||
return key_map
|
||||
|
||||
def model_lora_keys_unet(model, key_map={}):
|
||||
sdk = model.state_dict().keys()
|
||||
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = k
|
||||
|
||||
diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
key_lora = k[:-len(".weight")].replace(".", "_")
|
||||
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
|
||||
return key_map
|
||||
|
||||
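The new `model_lora_keys_unet` above replaces the hand-written block/counter tables with a generic rule: every `diffusion_model.*.weight` key maps to a `lora_unet_*` name by swapping dots for underscores (plus a diffusers-style alias via `unet_to_diffusers`). A standalone sketch of the dot-to-underscore rule:

```python
# Standalone sketch of the new generic UNet LoRA key mapping.
def unet_lora_key(state_dict_key):
    assert state_dict_key.startswith("diffusion_model.") and state_dict_key.endswith(".weight")
    key_lora = state_dict_key[len("diffusion_model."):-len(".weight")].replace(".", "_")
    return "lora_unet_{}".format(key_lora)

print(unet_lora_key("diffusion_model.middle_block.1.proj_in.weight"))
# -> lora_unet_middle_block_1_proj_in
```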
class ModelPatcher:
def __init__(self, model, size=0):
def __init__(self, model, load_device, offload_device, size=0):
self.size = size
self.model = model
self.patches = []
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device

def model_size(self):
if self.size > 0:
@@ -310,7 +226,7 @@ class ModelPatcher:
return size

def clone(self):
n = ModelPatcher(self.model, self.size)
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
n.patches = self.patches[:]
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
@@ -322,6 +238,9 @@ class ModelPatcher:
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function

def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function

def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
@@ -372,7 +291,8 @@ class ModelPatcher:
patch_list[k] = patch_list[k].to(device)

def model_dtype(self):
return self.model.get_dtype()
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()

def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = {}
@@ -481,10 +401,10 @@ class ModelPatcher:

self.backup = {}

def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
key_map = model_lora_keys(model.model)
key_map = model_lora_keys(clip.cond_stage_model, key_map)
loaded = load_lora(lora_path, key_map)
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = model_lora_keys_unet(model.model)
key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
loaded = load_lora(lora, key_map)
new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
new_clip = clip.clone()
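A sketch of the new call pattern, assuming `model` and `clip` come from a checkpoint loader: `load_lora_for_models` now receives an already loaded LoRA state dict rather than a file path, matching the `LoraLoader` change later in this commit:

```python
import comfy.sd
import comfy.utils

def apply_example_lora(model, clip, lora_path, strength=1.0):
    # the LoRA file is loaded once up front and passed in as a state dict
    lora_sd = comfy.utils.load_torch_file(lora_path, safe_load=True)
    return comfy.sd.load_lora_for_models(model, clip, lora_sd, strength, strength)
```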
@@ -502,17 +422,22 @@ class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False):
if no_init:
return
params = target.params
params = target.params.copy()
clip = target.clip
tokenizer = target.tokenizer

self.device = model_management.text_encoder_device()
params["device"] = self.device
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = load_device
self.cond_stage_model = clip(**(params))
self.cond_stage_model = self.cond_stage_model.to(self.device)
#TODO: make sure this doesn't have a quality loss before enabling.
# if model_management.should_use_fp16(load_device):
# self.cond_stage_model.half()

self.cond_stage_model = self.cond_stage_model.to()

self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = ModelPatcher(self.cond_stage_model)
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.layer_idx = None

def clone(self):
@@ -521,7 +446,6 @@ class CLIP:
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx
n.device = self.device
return n

def load_from_state_dict(self, sd):
@@ -539,18 +463,12 @@ class CLIP:
def encode_from_tokens(self, tokens, return_pooled=False):
if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx)
try:
self.patcher.patch_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
self.patcher.unpatch_model()
except Exception as e:
self.patcher.unpatch_model()
raise e

cond_out = cond
model_management.load_model_gpu(self.patcher)
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
if return_pooled:
return cond_out, pooled
return cond_out
return cond, pooled
return cond

def encode(self, text):
tokens = self.tokenize(text)
@@ -559,6 +477,15 @@ class CLIP:
def load_sd(self, sd):
return self.cond_stage_model.load_sd(sd)

def get_sd(self):
return self.cond_stage_model.state_dict()

def patch_model(self):
self.patcher.patch_model()

def unpatch_model(self):
self.patcher.unpatch_model()

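Illustrative use of the reworked encode path; this mirrors how the SDXL text-encode nodes added in this commit call it, and uses no API beyond what the hunks above show:

```python
def encode_prompt(clip, text):
    tokens = clip.tokenize(text)
    # encode_from_tokens now loads the text encoder through model_management.load_model_gpu
    # and can return the pooled output alongside the per-token conditioning.
    cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
    return cond, pooled
```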
class VAE:
def __init__(self, ckpt_path=None, device=None, config=None):
if config is None:
@@ -575,8 +502,11 @@ class VAE:
self.first_stage_model.load_state_dict(sd, strict=False)

if device is None:
device = model_management.get_torch_device()
device = model_management.vae_device()
self.device = device
self.offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype()
self.first_stage_model.to(self.vae_dtype)

def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
@@ -584,7 +514,7 @@ class VAE:
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)

decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0)
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
output = torch.clamp((
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
@@ -598,7 +528,7 @@ class VAE:
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)

encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample()
encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.vae_dtype).to(self.device) - 1.).sample().float()
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@@ -615,13 +545,13 @@ class VAE:

pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu().float()
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)

self.first_stage_model = self.first_stage_model.cpu()
self.first_stage_model = self.first_stage_model.to(self.offload_device)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples

@@ -629,7 +559,7 @@ class VAE:
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.cpu()
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return output.movedim(1,-1)

def encode(self, pixel_samples):
@@ -642,14 +572,14 @@ class VAE:
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu()
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float()

except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)

self.first_stage_model = self.first_stage_model.cpu()
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples

def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -657,9 +587,13 @@ class VAE:
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
self.first_stage_model = self.first_stage_model.cpu()
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples

def get_sd(self):
return self.first_stage_model.state_dict()

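The VAE hunks above all follow one pattern: cast inputs to `vae_dtype`, run on the load device, cast the result back to float32, then move the model to the offload device. A condensed sketch of that flow (simplified, without tiling or the OOM fallback):

```python
import torch

def decode_latents(vae, latents):
    # cast to the VAE's working dtype, run on its load device
    samples = latents.to(vae.vae_dtype).to(vae.device)
    pixels = torch.clamp((vae.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0)
    # park the model on the offload device afterwards, return float32 pixels on CPU
    vae.first_stage_model = vae.first_stage_model.to(vae.offload_device)
    return pixels.cpu().float()
```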
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
@@ -1061,6 +995,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if fp16:
model = model.half()

offload_device = model_management.unet_offload_device()
model = model.to(offload_device)
model.load_model_weights(state_dict, "model.diffusion_model.")

if output_vae:
@@ -1083,8 +1019,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
w.cond_stage_model = clip.cond_stage_model
load_clip_weights(w, state_dict)

return (ModelPatcher(model), clip, vae)
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)

def calculate_parameters(sd, prefix):
params = 0
for k in sd.keys():
if k.startswith(prefix):
params += sd[k].nelement()
return params

def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
sd = utils.load_torch_file(ckpt_path)
@@ -1095,7 +1037,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model = None
clip_target = None

fp16 = model_management.should_use_fp16()
parameters = calculate_parameters(sd, "model.diffusion_model.")
fp16 = model_management.should_use_fp16(model_params=parameters)

class WeightsLoader(torch.nn.Module):
pass
@@ -1108,7 +1051,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

model = model_config.get_model(sd)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.")
model = model.to(offload_device)
model.load_model_weights(sd, "model.diffusion_model.")

if output_vae:
@@ -1129,4 +1074,84 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if len(left_over) > 0:
print("left over keys:", left_over)

return (ModelPatcher(model), clip, vae, clipvision)
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)


def load_unet(unet_path): #load unet in diffusers format
sd = utils.load_torch_file(unet_path)
parameters = calculate_parameters(sd, "")
fp16 = model_management.should_use_fp16(model_params=parameters)

match = {}
match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
match["model_channels"] = sd["conv_in.weight"].shape[0]
match["in_channels"] = sd["conv_in.weight"].shape[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in sd:
match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]

SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}

SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}

SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}

SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}

SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}

SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}

supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
print("match", match)
for unet_config in supported_models:
matches = True
for k in match:
if match[k] != unet_config[k]:
matches = False
break
if matches:
diffusers_keys = utils.unet_to_diffusers(unet_config)
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model_config = model_detection.model_config_from_unet_config(unet_config)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)

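For context, the new `UNETLoader` node added later in this commit drives `load_unet` roughly like this (a sketch, assuming a diffusers-format UNet file placed under `models/unet`):

```python
import comfy.sd
import folder_paths

def load_diffusers_unet(unet_name):
    # resolve the file from the new "unet" folder registered in folder_paths
    unet_path = folder_paths.get_full_path("unet", unet_name)
    # returns a ModelPatcher wrapping the detected model config
    return comfy.sd.load_unet(unet_path)
```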
def save_checkpoint(output_path, model, clip, vae, metadata=None):
try:
model.patch_model()
clip.patch_model()
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
utils.save_torch_file(sd, output_path, metadata=metadata)
model.unpatch_model()
clip.unpatch_model()
except Exception as e:
model.unpatch_model()
clip.unpatch_model()
raise e

@@ -5,24 +5,34 @@ import comfy.ops
import torch
import traceback
import zipfile
from . import model_management
import contextlib

class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
z_empty, _ = self.encode(self.empty_tokens)
output = []
first_pooled = None
to_encode = list(self.empty_tokens)
for x in token_weight_pairs:
tokens = [list(map(lambda a: a[0], x))]
z, pooled = self.encode(tokens)
if first_pooled is None:
first_pooled = pooled
tokens = list(map(lambda a: a[0], x))
to_encode.append(tokens)

out, pooled = self.encode(to_encode)
z_empty = out[0:1]
if pooled.shape[0] > 1:
first_pooled = pooled[1:2]
else:
first_pooled = pooled[0:1]

output = []
for k in range(1, out.shape[0]):
z = out[k:k+1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = x[j][1]
weight = token_weight_pairs[k - 1][j][1]
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
output += [z]
output.append(z)

if (len(output) == 0):
return self.encode(self.empty_tokens)
return z_empty, first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()

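The weighting step above interpolates each token embedding toward the empty-prompt embedding. A standalone sketch of the same arithmetic (tensor and function names are illustrative, not part of the diff):

```python
import torch

def apply_token_weights(z, z_empty, weights):
    # z: [1, tokens, dim], z_empty: [1, tokens, dim], weights: one float per token
    w = torch.tensor(weights).view(1, -1, 1).to(z)
    # weight 1.0 keeps the embedding, weight 0.0 collapses it to the empty prompt
    return (z - z_empty) * w + z_empty
```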
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
@@ -46,7 +56,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)

self.device = device
self.max_length = max_length
if freeze:
self.freeze()
@@ -95,7 +104,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
out_tokens += [tokens_temp]

if len(embedding_weights) > 0:
new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1])
new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
n = token_dict_size
for x in embedding_weights:
@@ -106,24 +115,32 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):

def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(self.device)
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
self.transformer.set_input_embeddings(backup_embeds)
tokens = torch.LongTensor(tokens).to(device)

if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
if backup_embeds.weight.dtype != torch.float32:
precision_scope = torch.autocast
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z)
precision_scope = contextlib.nullcontext

pooled_output = outputs.pooler_output
if self.text_projection is not None:
pooled_output = pooled_output @ self.text_projection
return z, pooled_output
with precision_scope(model_management.get_autocast_device(device)):
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
self.transformer.set_input_embeddings(backup_embeds)

if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z)

pooled_output = outputs.pooler_output
if self.text_projection is not None:
pooled_output = pooled_output @ self.text_projection
return z.float(), pooled_output.float()

def encode(self, tokens):
return self(tokens)

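The new `forward` picks between `torch.autocast` and a null context depending on the embedding dtype, then casts the outputs back to float32. A stripped-down sketch of that pattern outside the class (module and tensor names are made up):

```python
import contextlib
import torch

def run_in_matching_precision(module, x):
    # autocast only when the module's weights are not already fp32
    if next(module.parameters()).dtype != torch.float32:
        scope = torch.autocast
    else:
        scope = contextlib.nullcontext
    with scope("cuda" if x.is_cuda else "cpu"):
        out = module(x)
    return out.float()
```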
@@ -3,9 +3,9 @@ import torch
import os

class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config)
super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
if layer == "last":
pass

@@ -3,11 +3,12 @@ import torch
import os

class SDXLClipG(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config)
super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.layer_norm_hidden_state = False
if layer == "last":
pass

@@ -9,6 +9,8 @@ from . import sdxl_clip
from . import supported_models_base
from . import latent_formats

from . import diffusers_convert

class SD15(supported_models_base.BASE):
unet_config = {
"context_dim": 768,
@@ -51,9 +53,9 @@ class SD20(supported_models_base.BASE):

latent_format = latent_formats.SD15

def v_prediction(self, state_dict):
def v_prediction(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
out = state_dict[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
return True
@@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
replace_prefix[""] = "cond_stage_model.model."
state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict

def clip_target(self):
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

@@ -100,7 +109,7 @@ class SDXLRefiner(supported_models_base.BASE):

latent_format = latent_formats.SDXL

def get_model(self, state_dict):
def get_model(self, state_dict, prefix=""):
return model_base.SDXLRefiner(self)

def process_clip_state_dict(self, state_dict):
@@ -109,10 +118,18 @@ class SDXLRefiner(supported_models_base.BASE):

state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"

state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g

def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

@@ -127,7 +144,7 @@ class SDXL(supported_models_base.BASE):

latent_format = latent_formats.SDXL

def get_model(self, state_dict):
def get_model(self, state_dict, prefix=""):
return model_base.SDXL(self)

def process_clip_state_dict(self, state_dict):
@@ -137,11 +154,25 @@ class SDXL(supported_models_base.BASE):
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"

state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
for k in state_dict:
if k.startswith("clip_l"):
state_dict_g[k] = state_dict[k]

replace_prefix["clip_g"] = "conditioner.embedders.1.model"
replace_prefix["clip_l"] = "conditioner.embedders.0"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g

def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)


@@ -41,7 +41,7 @@ class BASE:
return False
return True

def v_prediction(self, state_dict):
def v_prediction(self, state_dict, prefix=""):
return False

def inpaint_model(self):
@@ -53,14 +53,26 @@ class BASE:
for x in self.unet_extra_config:
self.unet_config[x] = self.unet_extra_config[x]

def get_model(self, state_dict):
def get_model(self, state_dict, prefix=""):
if self.inpaint_model():
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
elif self.noise_aug_config is not None:
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
else:
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))

def process_clip_state_dict(self, state_dict):
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)

def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.diffusion_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)

def process_vae_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "first_stage_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)

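The `process_*_state_dict_for_saving` helpers above all re-namespace keys through a prefix map. As a minimal sketch of what a `{"": prefix}` replacement amounts to (the helper name here is hypothetical, not the repo's own implementation):

```python
def prefix_for_saving(state_dict, prefix):
    # e.g. prefix="model.diffusion_model." re-namespaces bare UNet keys for a full checkpoint
    return {prefix + k: v for k, v in state_dict.items()}
```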
154 comfy/utils.py
@@ -2,10 +2,10 @@ import torch
import math
import struct
import comfy.checkpoint_pickle
import safetensors.torch

def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
if safe_load:
@@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False):
sd = pl_sd
return sd

def save_torch_file(sd, ckpt, metadata=None):
if metadata is not None:
safetensors.torch.save_file(sd, ckpt, metadata=metadata)
else:
safetensors.torch.save_file(sd, ckpt)

def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight",
@@ -64,6 +70,152 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd

UNET_MAP_ATTENTIONS = {
"proj_in.weight",
"proj_in.bias",
"proj_out.weight",
"proj_out.bias",
"norm.weight",
"norm.bias",
}

TRANSFORMER_BLOCKS = {
"norm1.weight",
"norm1.bias",
"norm2.weight",
"norm2.bias",
"norm3.weight",
"norm3.bias",
"attn1.to_q.weight",
"attn1.to_k.weight",
"attn1.to_v.weight",
"attn1.to_out.0.weight",
"attn1.to_out.0.bias",
"attn2.to_q.weight",
"attn2.to_k.weight",
"attn2.to_v.weight",
"attn2.to_out.0.weight",
"attn2.to_out.0.bias",
"ff.net.0.proj.weight",
"ff.net.0.proj.bias",
"ff.net.2.weight",
"ff.net.2.bias",
}

UNET_MAP_RESNET = {
"in_layers.2.weight": "conv1.weight",
"in_layers.2.bias": "conv1.bias",
"emb_layers.1.weight": "time_emb_proj.weight",
"emb_layers.1.bias": "time_emb_proj.bias",
"out_layers.3.weight": "conv2.weight",
"out_layers.3.bias": "conv2.bias",
"skip_connection.weight": "conv_shortcut.weight",
"skip_connection.bias": "conv_shortcut.bias",
"in_layers.0.weight": "norm1.weight",
"in_layers.0.bias": "norm1.bias",
"out_layers.0.weight": "norm2.weight",
"out_layers.0.bias": "norm2.bias",
}

UNET_MAP_BASIC = {
"label_emb.0.0.weight": "class_embedding.linear_1.weight",
"label_emb.0.0.bias": "class_embedding.linear_1.bias",
"label_emb.0.2.weight": "class_embedding.linear_2.weight",
"label_emb.0.2.bias": "class_embedding.linear_2.bias",
"input_blocks.0.0.weight": "conv_in.weight",
"input_blocks.0.0.bias": "conv_in.bias",
"out.0.weight": "conv_norm_out.weight",
"out.0.bias": "conv_norm_out.bias",
"out.2.weight": "conv_out.weight",
"out.2.bias": "conv_out.bias",
"time_embed.0.weight": "time_embedding.linear_1.weight",
"time_embed.0.bias": "time_embedding.linear_1.bias",
"time_embed.2.weight": "time_embedding.linear_2.weight",
"time_embed.2.bias": "time_embedding.linear_2.bias"
}

def unet_to_diffusers(unet_config):
num_res_blocks = unet_config["num_res_blocks"]
attention_resolutions = unet_config["attention_resolutions"]
channel_mult = unet_config["channel_mult"]
transformer_depth = unet_config["transformer_depth"]
num_blocks = len(channel_mult)
if isinstance(num_res_blocks, int):
num_res_blocks = [num_res_blocks] * num_blocks
if isinstance(transformer_depth, int):
transformer_depth = [transformer_depth] * num_blocks

transformers_per_layer = []
res = 1
for i in range(num_blocks):
transformers = 0
if res in attention_resolutions:
transformers = transformer_depth[i]
transformers_per_layer.append(transformers)
res *= 2

transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])

diffusers_unet_map = {}
for x in range(num_blocks):
n = 1 + (num_res_blocks[x] + 1) * x
for i in range(num_res_blocks[x]):
for b in UNET_MAP_RESNET:
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
if transformers_per_layer[x] > 0:
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
for t in range(transformers_per_layer[x]):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
n += 1
for k in ["weight", "bias"]:
diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)

i = 0
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
for t in range(transformers_mid):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)

for i, n in enumerate([0, 2]):
for b in UNET_MAP_RESNET:
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)

num_res_blocks = list(reversed(num_res_blocks))
transformers_per_layer = list(reversed(transformers_per_layer))
for x in range(num_blocks):
n = (num_res_blocks[x] + 1) * x
l = num_res_blocks[x] + 1
for i in range(l):
c = 0
for b in UNET_MAP_RESNET:
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
c += 1
if transformers_per_layer[x] > 0:
c += 1
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
for t in range(transformers_per_layer[x]):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
if i == l - 1:
for k in ["weight", "bias"]:
diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
n += 1

for k in UNET_MAP_BASIC:
diffusers_unet_map[UNET_MAP_BASIC[k]] = k

return diffusers_unet_map

def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:
state_dict[k] = state_dict[k].to(dtype)
return state_dict

def safetensors_header(safetensors_path, max_size=100*1024*1024):
with open(safetensors_path, "rb") as f:
header = f.read(8)

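A small usage sketch for the new `unet_to_diffusers` map; the config literal here is an assumption modeled on the SD15 dict shown earlier in this commit, and the snippet is meant to be run inside the repo:

```python
import comfy.utils

unet_config = {"num_res_blocks": 2, "attention_resolutions": [1, 2, 4],
               "channel_mult": [1, 2, 4, 4], "transformer_depth": [1, 1, 1, 0]}
mapping = comfy.utils.unet_to_diffusers(unet_config)
# diffusers-style key on the left, internal UNet key on the right
print(mapping["down_blocks.0.resnets.0.conv1.weight"])  # expected: "input_blocks.1.0.in_layers.2.weight"
```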
56 comfy_extras/nodes_clip_sdxl.py (new file)
@@ -0,0 +1,56 @@
import torch
from nodes import MAX_RESOLUTION

class CLIPTextEncodeSDXLRefiner:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"text": ("STRING", {"multiline": True}), "clip": ("CLIP", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"

CATEGORY = "advanced/conditioning"

def encode(self, clip, ascore, width, height, text):
tokens = clip.tokenize(text)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], )

class CLIPTextEncodeSDXL:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
"crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
"target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"text_g": ("STRING", {"multiline": True, "default": "CLIP_G"}), "clip": ("CLIP", ),
"text_l": ("STRING", {"multiline": True, "default": "CLIP_L"}), "clip": ("CLIP", ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"

CATEGORY = "advanced/conditioning"

def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l):
tokens = clip.tokenize(text_g)
tokens["l"] = clip.tokenize(text_l)["l"]
if len(tokens["l"]) != len(tokens["g"]):
empty = clip.tokenize("")
while len(tokens["l"]) < len(tokens["g"]):
tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"]
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], )

NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
"CLIPTextEncodeSDXL": CLIPTextEncodeSDXL,
}

@@ -1,4 +1,8 @@

import comfy.sd
import comfy.utils
import folder_paths
import json
import os

class ModelMergeSimple:
@classmethod
@@ -10,7 +14,7 @@ class ModelMergeSimple:
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"

CATEGORY = "_for_testing/model_merging"
CATEGORY = "advanced/model_merging"

def merge(self, model1, model2, ratio):
m = model1.clone()
@@ -31,7 +35,7 @@ class ModelMergeBlocks:
RETURN_TYPES = ("MODEL",)
FUNCTION = "merge"

CATEGORY = "_for_testing/model_merging"
CATEGORY = "advanced/model_merging"

def merge(self, model1, model2, **kwargs):
m = model1.clone()
@@ -42,14 +46,52 @@ class ModelMergeBlocks:
ratio = default_ratio
k_unet = k[len("diffusion_model."):]

last_arg_size = 0
for arg in kwargs:
if k_unet.startswith(arg):
if k_unet.startswith(arg) and last_arg_size < len(arg):
ratio = kwargs[arg]
last_arg_size = len(arg)

m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
return (m, )

class CheckpointSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()

@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True

CATEGORY = "advanced/model_merging"

def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)

metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])

output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {}


NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks
"ModelMergeBlocks": ModelMergeBlocks,
"CheckpointSave": CheckpointSave,
}

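`ModelMergeBlocks` now resolves the most specific matching block prefix before patching. A sketch of just that lookup, with illustrative argument names:

```python
def ratio_for_key(k_unet, default_ratio, **kwargs):
    # pick the ratio attached to the longest prefix that matches this weight's name
    ratio = default_ratio
    last_arg_size = 0
    for arg, value in kwargs.items():
        if k_unet.startswith(arg) and last_arg_size < len(arg):
            ratio = value
            last_arg_size = len(arg)
    return ratio

ratio_for_key("input_blocks.4.1.proj_in.weight", 0.5, input=0.2, output=0.8)  # -> 0.2
```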
24 execution.py
@@ -110,7 +110,7 @@ def format_value(x):
else:
return str(x)

def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
@@ -125,7 +125,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
if result[0] is not True:
# Another node failed further upstream
return result
@@ -136,7 +136,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = class_def()

obj = object_storage.get((unique_id, class_type), None)
if obj is None:
obj = class_def()
object_storage[(unique_id, class_type)] = obj

output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
@@ -256,6 +260,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor:
def __init__(self, server):
self.outputs = {}
self.object_storage = {}
self.outputs_ui = {}
self.old_prompt = {}
self.server = server
@@ -322,6 +327,17 @@ class PromptExecutor:
for o in to_delete:
d = self.outputs.pop(o)
del d
to_delete = []
for o in self.object_storage:
if o[0] not in prompt:
to_delete += [o]
else:
p = prompt[o[0]]
if o[1] != p['class_type']:
to_delete += [o]
for o in to_delete:
d = self.object_storage.pop(o)
del d

for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
@@ -349,7 +365,7 @@ class PromptExecutor:
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break

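The execution change above caches node objects between runs, keyed by node id and class type, and evicts entries whose node disappeared or changed class. A minimal sketch of that cache:

```python
object_storage = {}

def get_node_object(unique_id, class_type, class_def):
    # reuse the instance created for this node on a previous prompt, if any
    obj = object_storage.get((unique_id, class_type), None)
    if obj is None:
        obj = class_def()
        object_storage[(unique_id, class_type)] = obj
    return obj
```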
@@ -8,9 +8,12 @@ a111:
checkpoints: models/Stable-diffusion
configs: models/Stable-diffusion
vae: models/VAE
loras: models/Lora
loras: |
models/Lora
models/LyCORIS
upscale_models: |
models/ESRGAN
models/RealESRGAN
models/SwinIR
embeddings: embeddings
hypernetworks: models/hypernetworks
@@ -21,5 +24,3 @@ a111:
# checkpoints: models/checkpoints
# gligen: models/gligen
# custom_nodes: path/custom_nodes


@@ -14,6 +14,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)

@@ -49,14 +49,8 @@ class TAESDPreviewerImpl(LatentPreviewer):


class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self):
self.latent_rgb_factors = torch.tensor([
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
], device="cpu")
def __init__(self, latent_rgb_factors):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")

def decode_latent_to_preview(self, x0):
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
@@ -69,12 +63,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
return Image.fromarray(latents_ubyte.numpy())


def get_previewer(device):
def get_previewer(device, latent_format):
previewer = None
method = args.preview_method
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth")
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)

if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
@@ -86,10 +80,10 @@ def get_previewer(device):
taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth")
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

if previewer is None:
previewer = Latent2RGBPreviewer()
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
return previewer

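For reference, the latent-to-RGB preview is a single matrix product against `latent_rgb_factors`, which now comes from the model's latent format. A sketch using the SD1.x factors from the removed lines:

```python
import torch

# 4x3 factor matrix copied from the removed Latent2RGBPreviewer defaults
factors = torch.tensor([[0.298, 0.207, 0.208],
                        [0.187, 0.286, 0.173],
                        [-0.158, 0.189, 0.264],
                        [-0.184, -0.271, -0.473]])

def latent_to_rgb(x0):
    # x0: a single [4, H, W] latent -> [H, W, 3] float image, roughly in [-1, 1]
    return x0.permute(1, 2, 0) @ factors
```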
4 main.py
@@ -14,10 +14,6 @@ if os.name == "nt":
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

if __name__ == "__main__":
if args.dont_upcast_attention:
print("disabling upcasting of attention")
os.environ['ATTN_PRECISION'] = "fp16"

if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device)

0 models/unet/put_unet_files_here (new file)
100 nodes.py
@@ -86,16 +86,52 @@ class ConditioningAverage :
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")

cond_from = conditioning_from[0][0]
pooled_output_from = conditioning_from[0][1].get("pooled_output", None)

for i in range(len(conditioning_to)):
t1 = conditioning_to[i][0]
pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
t0 = cond_from[:,:t1.shape[1]]
if t0.shape[1] < t1.shape[1]:
t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)

tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
t_to = conditioning_to[i][1].copy()
if pooled_output_from is not None and pooled_output_to is not None:
t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
elif pooled_output_from is not None:
t_to["pooled_output"] = pooled_output_from

n = [tw, t_to]
out.append(n)
return (out, )

class ConditioningConcat:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning_to": ("CONDITIONING",),
"conditioning_from": ("CONDITIONING",),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "concat"

CATEGORY = "advanced/conditioning"

def concat(self, conditioning_to, conditioning_from):
out = []

if len(conditioning_from) > 1:
print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")

cond_from = conditioning_from[0][0]

for i in range(len(conditioning_to)):
t1 = conditioning_to[i][0]
tw = torch.cat((t1, cond_from),1)
n = [tw, conditioning_to[i][1].copy()]
out.append(n)

return (out, )

class ConditioningSetArea:
@@ -152,6 +188,25 @@ class ConditioningSetMask:
c.append(n)
return (c, )

class ConditioningZeroOut:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", )}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "zero_out"

CATEGORY = "advanced/conditioning"

def zero_out(self, conditioning):
c = []
for t in conditioning:
d = t[1].copy()
if "pooled_output" in d:
d["pooled_output"] = torch.zeros_like(d["pooled_output"])
n = [torch.zeros_like(t[0]), d]
c.append(n)
return (c, )

class VAEDecode:
@classmethod
def INPUT_TYPES(s):
@@ -290,8 +345,7 @@ class SaveLatent:
output["latent_tensor"] = samples["samples"]
output["latent_format_version_0"] = torch.tensor([])

safetensors.torch.save_file(output, file, metadata=metadata)

comfy.utils.save_torch_file(output, file, metadata=metadata)
return {}


@@ -375,7 +429,7 @@ class DiffusersLoader:
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"

CATEGORY = "advanced/loaders"
CATEGORY = "advanced/loaders/deprecated"

def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
for search_path in folder_paths.get_folder_paths("diffusers"):
@@ -420,6 +474,9 @@ class CLIPSetLastLayer:
return (clip,)

class LoraLoader:
def __init__(self):
self.loaded_lora = None

@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
@@ -438,7 +495,18 @@ class LoraLoader:
return (model, clip)

lora_path = folder_paths.get_full_path("loras", lora_name)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
lora = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
del self.loaded_lora

if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)

model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
return (model_lora, clip_lora)

class VAELoader:
@@ -516,6 +584,21 @@ class ControlNetApply:
c.append(n)
return (c, )

class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"

CATEGORY = "advanced/loaders"

def load_unet(self, unet_name):
unet_path = folder_paths.get_full_path("unet", unet_name)
model = comfy.sd.load_unet(unet_path)
return (model,)

class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
@@ -958,7 +1041,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"

previewer = latent_preview.get_previewer(device)
previewer = latent_preview.get_previewer(device, model.model.latent_format)

pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
@@ -969,7 +1052,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,

samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback)
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed)
out = latent.copy()
out["samples"] = samples
return (out, )
@@ -1335,6 +1418,7 @@ NODE_CLASS_MAPPINGS = {
"LatentCrop": LatentCrop,
"LoraLoader": LoraLoader,
"CLIPLoader": CLIPLoader,
"UNETLoader": UNETLoader,
"DualCLIPLoader": DualCLIPLoader,
"CLIPVisionEncode": CLIPVisionEncode,
"StyleModelApply": StyleModelApply,
@@ -1355,6 +1439,9 @@ NODE_CLASS_MAPPINGS = {

"LoadLatent": LoadLatent,
"SaveLatent": SaveLatent,

"ConditioningZeroOut": ConditioningZeroOut,
"ConditioningConcat": ConditioningConcat,
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -1516,6 +1603,7 @@ def init_custom_nodes():
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
load_custom_nodes()
if args.monitor_nodes:
print("Monitoring custom nodes for modifications.\n")

@@ -144,6 +144,7 @@
"\n",
"\n",
"# ESRGAN upscale model\n",
"#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
"\n",

@@ -1468,7 +1468,7 @@ export class ComfyApp {
this.loadGraphData(JSON.parse(reader.result));
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent")) {
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
const info = await getLatentMetadata(file);
if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow));

@@ -55,11 +55,12 @@ export function getLatentMetadata(file) {
const dataView = new DataView(safetensorsData.buffer);
let header_size = dataView.getUint32(0, true);
let offset = 8;
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size)));
r(header.__metadata__);
};

reader.readAsArrayBuffer(file);
var slice = file.slice(0, 1024 * 1024 * 4);
reader.readAsArrayBuffer(slice);
});
}


@@ -545,7 +545,7 @@ export class ComfyUI {
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",
accept: ".json,image/png,.latent",
accept: ".json,image/png,.latent,.safetensors",
style: {display: "none"},
parent: document.body,
onchange: () => {