diff --git a/.gitignore b/.gitignore
index 8380a2f7c..38d2ba11b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,8 @@ custom_nodes/
!custom_nodes/example_node.py.example
extra_model_paths.yaml
/.vs
-.idea/
\ No newline at end of file
+.idea/
+venv/
+web/extensions/*
+!web/extensions/logging.js.example
+!web/extensions/core/
\ No newline at end of file
diff --git a/README.md b/README.md
index d9083b7e1..84c10bfe2 100644
--- a/README.md
+++ b/README.md
@@ -87,13 +87,13 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
-At the time of writing this pytorch has issues with python versions higher than 3.10 so make sure your python/pip versions are 3.10.
-
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
+This is the command to install the nightly with ROCm 5.5 that supports the 7000 series and might have some performance improvements:
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.5 -r requirements.txt```
### NVIDIA
@@ -119,12 +119,22 @@ After this you should have everything installed and can proceed to running Comfy
### Others:
-[Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)
+#### [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)
-Mac/MPS: There is basic support in the code but until someone makes some install instruction you are on your own.
+#### Apple Mac silicon
-Directml: ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
+You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
+1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide.
+1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
+1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
+1. Launch ComfyUI by running `python main.py`.
+
+> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
+
+#### DirectML (AMD Cards on Windows)
+
+```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies?
@@ -168,16 +178,6 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
```embedding:embedding_filename.pt```
-### Fedora
-
-To get python 3.10 on fedora:
-```dnf install python3.10```
-
-Then you can:
-
-```python3.10 -m ensurepip```
-
-This will let you use: pip3.10 to install all the dependencies.
## How to increase generation speed?
diff --git a/comfy/checkpoint_pickle.py b/comfy/checkpoint_pickle.py
new file mode 100644
index 000000000..206551d3c
--- /dev/null
+++ b/comfy/checkpoint_pickle.py
@@ -0,0 +1,13 @@
+import pickle
+
+load = pickle.load
+
+class Empty:
+ pass
+
+class Unpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ #TODO: safe unpickle
+ if module.startswith("pytorch_lightning"):
+ return Empty
+ return super().find_class(module, name)
diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py
index cb660ee77..aa667f1aa 100644
--- a/comfy/cldm/cldm.py
+++ b/comfy/cldm/cldm.py
@@ -14,8 +14,7 @@ from ..ldm.modules.diffusionmodules.util import (
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
-from ..ldm.models.diffusion.ddpm import LatentDiffusion
-from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
+from ..ldm.util import exists
class ControlledUnetModel(UNetModel):
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index b56497de0..f1306ef7f 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -59,12 +59,14 @@ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", he
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
vram_group = parser.add_mutually_exclusive_group()
+vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
+
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index efb2d5384..2036175b8 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -1,12 +1,15 @@
-from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
+from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
from .utils import load_torch_file, transformers_convert
import os
import torch
+import comfy.ops
class ClipVisionModel():
def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config)
- self.model = CLIPVisionModelWithProjection(config)
+ with comfy.ops.use_comfy_ops():
+ with modeling_utils.no_init_weights():
+ self.model = CLIPVisionModelWithProjection(config)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
@@ -18,7 +21,7 @@ class ClipVisionModel():
size=224)
def load_sd(self, sd):
- self.model.load_state_dict(sd, strict=False)
+ return self.model.load_state_dict(sd, strict=False)
def encode_image(self, image):
img = torch.clip((255. * image[0]), 0, 255).round().int()
@@ -56,7 +59,13 @@ def load_clipvision_from_sd(sd):
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
clip = ClipVisionModel(json_config)
- clip.load_sd(sd)
+ m, u = clip.load_sd(sd)
+ u = set(u)
+ keys = list(sd.keys())
+ for k in keys:
+ if k not in u:
+ t = sd.pop(k)
+ del t
return clip
def load(ckpt_path):
diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py
index f494f1d30..d6074c7d4 100644
--- a/comfy/diffusers_load.py
+++ b/comfy/diffusers_load.py
@@ -3,7 +3,6 @@ import os
import yaml
import folder_paths
-from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
import os.path as osp
import re
diff --git a/comfy/gligen.py b/comfy/gligen.py
index 8c7cb432e..fe3895c48 100644
--- a/comfy/gligen.py
+++ b/comfy/gligen.py
@@ -260,7 +260,8 @@ class Gligen(nn.Module):
return r
return func_lowvram
else:
- def func(key, x):
+ def func(x, extra_options):
+ key = extra_options["transformer_index"]
module = self.module_list[key]
return module(x, objs)
return func
diff --git a/comfy/k_diffusion/augmentation.py b/comfy/k_diffusion/augmentation.py
deleted file mode 100644
index 7dd17c686..000000000
--- a/comfy/k_diffusion/augmentation.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from functools import reduce
-import math
-import operator
-
-import numpy as np
-from skimage import transform
-import torch
-from torch import nn
-
-
-def translate2d(tx, ty):
- mat = [[1, 0, tx],
- [0, 1, ty],
- [0, 0, 1]]
- return torch.tensor(mat, dtype=torch.float32)
-
-
-def scale2d(sx, sy):
- mat = [[sx, 0, 0],
- [ 0, sy, 0],
- [ 0, 0, 1]]
- return torch.tensor(mat, dtype=torch.float32)
-
-
-def rotate2d(theta):
- mat = [[torch.cos(theta), torch.sin(-theta), 0],
- [torch.sin(theta), torch.cos(theta), 0],
- [ 0, 0, 1]]
- return torch.tensor(mat, dtype=torch.float32)
-
-
-class KarrasAugmentationPipeline:
- def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
- self.a_prob = a_prob
- self.a_scale = a_scale
- self.a_aniso = a_aniso
- self.a_trans = a_trans
-
- def __call__(self, image):
- h, w = image.size
- mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
-
- # x-flip
- a0 = torch.randint(2, []).float()
- mats.append(scale2d(1 - 2 * a0, 1))
- # y-flip
- do = (torch.rand([]) < self.a_prob).float()
- a1 = torch.randint(2, []).float() * do
- mats.append(scale2d(1, 1 - 2 * a1))
- # scaling
- do = (torch.rand([]) < self.a_prob).float()
- a2 = torch.randn([]) * do
- mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
- # rotation
- do = (torch.rand([]) < self.a_prob).float()
- a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
- mats.append(rotate2d(-a3))
- # anisotropy
- do = (torch.rand([]) < self.a_prob).float()
- a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
- a5 = torch.randn([]) * do
- mats.append(rotate2d(a4))
- mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
- mats.append(rotate2d(-a4))
- # translation
- do = (torch.rand([]) < self.a_prob).float()
- a6 = torch.randn([]) * do
- a7 = torch.randn([]) * do
- mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
-
- # form the transformation matrix and conditioning vector
- mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
- mat = reduce(operator.matmul, mats)
- cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
-
- # apply the transformation
- image_orig = np.array(image, dtype=np.float32) / 255
- if image_orig.ndim == 2:
- image_orig = image_orig[..., None]
- tf = transform.AffineTransform(mat.numpy())
- image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
- image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
- image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
- return image, image_orig, cond
-
-
-class KarrasAugmentWrapper(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.inner_model = model
-
- def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
- if aug_cond is None:
- aug_cond = input.new_zeros([input.shape[0], 9])
- if mapping_cond is None:
- mapping_cond = aug_cond
- else:
- mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
- return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
-
- def set_skip_stages(self, skip_stages):
- return self.inner_model.set_skip_stages(skip_stages)
-
- def set_patch_size(self, patch_size):
- return self.inner_model.set_patch_size(patch_size)
diff --git a/comfy/k_diffusion/config.py b/comfy/k_diffusion/config.py
deleted file mode 100644
index 4b504d6d7..000000000
--- a/comfy/k_diffusion/config.py
+++ /dev/null
@@ -1,110 +0,0 @@
-from functools import partial
-import json
-import math
-import warnings
-
-from jsonmerge import merge
-
-from . import augmentation, layers, models, utils
-
-
-def load_config(file):
- defaults = {
- 'model': {
- 'sigma_data': 1.,
- 'patch_size': 1,
- 'dropout_rate': 0.,
- 'augment_wrapper': True,
- 'augment_prob': 0.,
- 'mapping_cond_dim': 0,
- 'unet_cond_dim': 0,
- 'cross_cond_dim': 0,
- 'cross_attn_depths': None,
- 'skip_stages': 0,
- 'has_variance': False,
- },
- 'dataset': {
- 'type': 'imagefolder',
- },
- 'optimizer': {
- 'type': 'adamw',
- 'lr': 1e-4,
- 'betas': [0.95, 0.999],
- 'eps': 1e-6,
- 'weight_decay': 1e-3,
- },
- 'lr_sched': {
- 'type': 'inverse',
- 'inv_gamma': 20000.,
- 'power': 1.,
- 'warmup': 0.99,
- },
- 'ema_sched': {
- 'type': 'inverse',
- 'power': 0.6667,
- 'max_value': 0.9999
- },
- }
- config = json.load(file)
- return merge(defaults, config)
-
-
-def make_model(config):
- config = config['model']
- assert config['type'] == 'image_v1'
- model = models.ImageDenoiserModelV1(
- config['input_channels'],
- config['mapping_out'],
- config['depths'],
- config['channels'],
- config['self_attn_depths'],
- config['cross_attn_depths'],
- patch_size=config['patch_size'],
- dropout_rate=config['dropout_rate'],
- mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
- unet_cond_dim=config['unet_cond_dim'],
- cross_cond_dim=config['cross_cond_dim'],
- skip_stages=config['skip_stages'],
- has_variance=config['has_variance'],
- )
- if config['augment_wrapper']:
- model = augmentation.KarrasAugmentWrapper(model)
- return model
-
-
-def make_denoiser_wrapper(config):
- config = config['model']
- sigma_data = config.get('sigma_data', 1.)
- has_variance = config.get('has_variance', False)
- if not has_variance:
- return partial(layers.Denoiser, sigma_data=sigma_data)
- return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)
-
-
-def make_sample_density(config):
- sd_config = config['sigma_sample_density']
- sigma_data = config['sigma_data']
- if sd_config['type'] == 'lognormal':
- loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
- scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
- return partial(utils.rand_log_normal, loc=loc, scale=scale)
- if sd_config['type'] == 'loglogistic':
- loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
- scale = sd_config['scale'] if 'scale' in sd_config else 0.5
- min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
- max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
- return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
- if sd_config['type'] == 'loguniform':
- min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
- max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
- return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
- if sd_config['type'] == 'v-diffusion':
- min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
- max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
- return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
- if sd_config['type'] == 'split-lognormal':
- loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
- scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
- scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
- return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
- raise ValueError('Unknown sample density type')
diff --git a/comfy/k_diffusion/evaluation.py b/comfy/k_diffusion/evaluation.py
deleted file mode 100644
index 2c34bbf16..000000000
--- a/comfy/k_diffusion/evaluation.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import math
-import os
-from pathlib import Path
-
-from cleanfid.inception_torchscript import InceptionV3W
-import clip
-from resize_right import resize
-import torch
-from torch import nn
-from torch.nn import functional as F
-from torchvision import transforms
-from tqdm.auto import trange
-
-from . import utils
-
-
-class InceptionV3FeatureExtractor(nn.Module):
- def __init__(self, device='cpu'):
- super().__init__()
- path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
- url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
- digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
- utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
- self.model = InceptionV3W(str(path), resize_inside=False).to(device)
- self.size = (299, 299)
-
- def forward(self, x):
- if x.shape[2:4] != self.size:
- x = resize(x, out_shape=self.size, pad_mode='reflect')
- if x.shape[1] == 1:
- x = torch.cat([x] * 3, dim=1)
- x = (x * 127.5 + 127.5).clamp(0, 255)
- return self.model(x)
-
-
-class CLIPFeatureExtractor(nn.Module):
- def __init__(self, name='ViT-L/14@336px', device='cpu'):
- super().__init__()
- self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
- self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
- std=(0.26862954, 0.26130258, 0.27577711))
- self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)
-
- def forward(self, x):
- if x.shape[2:4] != self.size:
- x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
- x = self.normalize(x)
- x = self.model.encode_image(x).float()
- x = F.normalize(x) * x.shape[1] ** 0.5
- return x
-
-
-def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
- n_per_proc = math.ceil(n / accelerator.num_processes)
- feats_all = []
- try:
- for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
- cur_batch_size = min(n - i, batch_size)
- samples = sample_fn(cur_batch_size)[:cur_batch_size]
- feats_all.append(accelerator.gather(extractor_fn(samples)))
- except StopIteration:
- pass
- return torch.cat(feats_all)[:n]
-
-
-def polynomial_kernel(x, y):
- d = x.shape[-1]
- dot = x @ y.transpose(-2, -1)
- return (dot / d + 1) ** 3
-
-
-def squared_mmd(x, y, kernel=polynomial_kernel):
- m = x.shape[-2]
- n = y.shape[-2]
- kxx = kernel(x, x)
- kyy = kernel(y, y)
- kxy = kernel(x, y)
- kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
- kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
- kxy_sum = kxy.sum([-1, -2])
- term_1 = kxx_sum / m / (m - 1)
- term_2 = kyy_sum / n / (n - 1)
- term_3 = kxy_sum * 2 / m / n
- return term_1 + term_2 - term_3
-
-
-@utils.tf32_mode(matmul=False)
-def kid(x, y, max_size=5000):
- x_size, y_size = x.shape[0], y.shape[0]
- n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
- total_mmd = x.new_zeros([])
- for i in range(n_partitions):
- cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
- cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
- total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
- return total_mmd / n_partitions
-
-
-class _MatrixSquareRootEig(torch.autograd.Function):
- @staticmethod
- def forward(ctx, a):
- vals, vecs = torch.linalg.eigh(a)
- ctx.save_for_backward(vals, vecs)
- return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
-
- @staticmethod
- def backward(ctx, grad_output):
- vals, vecs = ctx.saved_tensors
- d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
- vecs_t = vecs.transpose(-2, -1)
- return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
-
-
-def sqrtm_eig(a):
- if a.ndim < 2:
- raise RuntimeError('tensor of matrices must have at least 2 dimensions')
- if a.shape[-2] != a.shape[-1]:
- raise RuntimeError('tensor must be batches of square matrices')
- return _MatrixSquareRootEig.apply(a)
-
-
-@utils.tf32_mode(matmul=False)
-def fid(x, y, eps=1e-8):
- x_mean = x.mean(dim=0)
- y_mean = y.mean(dim=0)
- mean_term = (x_mean - y_mean).pow(2).sum()
- x_cov = torch.cov(x.T)
- y_cov = torch.cov(y.T)
- eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
- x_cov = x_cov + eps_eye
- y_cov = y_cov + eps_eye
- x_cov_sqrt = sqrtm_eig(x_cov)
- cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
- return mean_term + cov_term
diff --git a/comfy/k_diffusion/gns.py b/comfy/k_diffusion/gns.py
deleted file mode 100644
index dcb7b8d8a..000000000
--- a/comfy/k_diffusion/gns.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import torch
-from torch import nn
-
-
-class DDPGradientStatsHook:
- def __init__(self, ddp_module):
- try:
- ddp_module.register_comm_hook(self, self._hook_fn)
- except AttributeError:
- raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
- self._clear_state()
-
- def _clear_state(self):
- self.bucket_sq_norms_small_batch = []
- self.bucket_sq_norms_large_batch = []
-
- @staticmethod
- def _hook_fn(self, bucket):
- buf = bucket.buffer()
- self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
- fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
- def callback(fut):
- buf = fut.value()[0]
- self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
- return buf
- return fut.then(callback)
-
- def get_stats(self):
- sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
- sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
- self._clear_state()
- stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
- torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
- return stats[0].item(), stats[1].item()
-
-
-class GradientNoiseScale:
- """Calculates the gradient noise scale (1 / SNR), or critical batch size,
- from _An Empirical Model of Large-Batch Training_,
- https://arxiv.org/abs/1812.06162).
-
- Args:
- beta (float): The decay factor for the exponential moving averages used to
- calculate the gradient noise scale.
- Default: 0.9998
- eps (float): Added for numerical stability.
- Default: 1e-8
- """
-
- def __init__(self, beta=0.9998, eps=1e-8):
- self.beta = beta
- self.eps = eps
- self.ema_sq_norm = 0.
- self.ema_var = 0.
- self.beta_cumprod = 1.
- self.gradient_noise_scale = float('nan')
-
- def state_dict(self):
- """Returns the state of the object as a :class:`dict`."""
- return dict(self.__dict__.items())
-
- def load_state_dict(self, state_dict):
- """Loads the object's state.
- Args:
- state_dict (dict): object state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- self.__dict__.update(state_dict)
-
- def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
- """Updates the state with a new batch's gradient statistics, and returns the
- current gradient noise scale.
-
- Args:
- sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
- per sample gradients.
- sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
- per sample gradients.
- n_small_batch (int): The batch size of the individual microbatch or per sample
- gradients (1 if per sample).
- n_large_batch (int): The total batch size of the mean of the microbatch or
- per sample gradients.
- """
- est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
- est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
- self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
- self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
- self.beta_cumprod *= self.beta
- self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
- return self.gradient_noise_scale
-
- def get_gns(self):
- """Returns the current gradient noise scale."""
- return self.gradient_noise_scale
-
- def get_stats(self):
- """Returns the current (debiased) estimates of the squared mean gradient
- and gradient variance."""
- return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
diff --git a/comfy/k_diffusion/layers.py b/comfy/k_diffusion/layers.py
deleted file mode 100644
index cdeba0ad6..000000000
--- a/comfy/k_diffusion/layers.py
+++ /dev/null
@@ -1,246 +0,0 @@
-import math
-
-from einops import rearrange, repeat
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from . import utils
-
-# Karras et al. preconditioned denoiser
-
-class Denoiser(nn.Module):
- """A Karras et al. preconditioner for denoising diffusion models."""
-
- def __init__(self, inner_model, sigma_data=1.):
- super().__init__()
- self.inner_model = inner_model
- self.sigma_data = sigma_data
-
- def get_scalings(self, sigma):
- c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
- c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
- c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
- return c_skip, c_out, c_in
-
- def loss(self, input, noise, sigma, **kwargs):
- c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
- noised_input = input + noise * utils.append_dims(sigma, input.ndim)
- model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
- target = (input - c_skip * noised_input) / c_out
- return (model_output - target).pow(2).flatten(1).mean(1)
-
- def forward(self, input, sigma, **kwargs):
- c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
- return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
-
-
-class DenoiserWithVariance(Denoiser):
- def loss(self, input, noise, sigma, **kwargs):
- c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
- noised_input = input + noise * utils.append_dims(sigma, input.ndim)
- model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
- logvar = utils.append_dims(logvar, model_output.ndim)
- target = (input - c_skip * noised_input) / c_out
- losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
- return losses.flatten(1).mean(1)
-
-
-# Residual blocks
-
-class ResidualBlock(nn.Module):
- def __init__(self, *main, skip=None):
- super().__init__()
- self.main = nn.Sequential(*main)
- self.skip = skip if skip else nn.Identity()
-
- def forward(self, input):
- return self.main(input) + self.skip(input)
-
-
-# Noise level (and other) conditioning
-
-class ConditionedModule(nn.Module):
- pass
-
-
-class UnconditionedModule(ConditionedModule):
- def __init__(self, module):
- super().__init__()
- self.module = module
-
- def forward(self, input, cond=None):
- return self.module(input)
-
-
-class ConditionedSequential(nn.Sequential, ConditionedModule):
- def forward(self, input, cond):
- for module in self:
- if isinstance(module, ConditionedModule):
- input = module(input, cond)
- else:
- input = module(input)
- return input
-
-
-class ConditionedResidualBlock(ConditionedModule):
- def __init__(self, *main, skip=None):
- super().__init__()
- self.main = ConditionedSequential(*main)
- self.skip = skip if skip else nn.Identity()
-
- def forward(self, input, cond):
- skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
- return self.main(input, cond) + skip
-
-
-class AdaGN(ConditionedModule):
- def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
- super().__init__()
- self.num_groups = num_groups
- self.eps = eps
- self.cond_key = cond_key
- self.mapper = nn.Linear(feats_in, c_out * 2)
-
- def forward(self, input, cond):
- weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
- input = F.group_norm(input, self.num_groups, eps=self.eps)
- return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)
-
-
-# Attention
-
-class SelfAttention2d(ConditionedModule):
- def __init__(self, c_in, n_head, norm, dropout_rate=0.):
- super().__init__()
- assert c_in % n_head == 0
- self.norm_in = norm(c_in)
- self.n_head = n_head
- self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
- self.out_proj = nn.Conv2d(c_in, c_in, 1)
- self.dropout = nn.Dropout(dropout_rate)
-
- def forward(self, input, cond):
- n, c, h, w = input.shape
- qkv = self.qkv_proj(self.norm_in(input, cond))
- qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
- q, k, v = qkv.chunk(3, dim=1)
- scale = k.shape[3] ** -0.25
- att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
- att = self.dropout(att)
- y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
- return input + self.out_proj(y)
-
-
-class CrossAttention2d(ConditionedModule):
- def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
- cond_key='cross', cond_key_padding='cross_padding'):
- super().__init__()
- assert c_dec % n_head == 0
- self.cond_key = cond_key
- self.cond_key_padding = cond_key_padding
- self.norm_enc = nn.LayerNorm(c_enc)
- self.norm_dec = norm_dec(c_dec)
- self.n_head = n_head
- self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
- self.kv_proj = nn.Linear(c_enc, c_dec * 2)
- self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
- self.dropout = nn.Dropout(dropout_rate)
-
- def forward(self, input, cond):
- n, c, h, w = input.shape
- q = self.q_proj(self.norm_dec(input, cond))
- q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
- kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
- kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
- k, v = kv.chunk(2, dim=1)
- scale = k.shape[3] ** -0.25
- att = ((q * scale) @ (k.transpose(2, 3) * scale))
- att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
- att = att.softmax(3)
- att = self.dropout(att)
- y = (att @ v).transpose(2, 3)
- y = y.contiguous().view([n, c, h, w])
- return input + self.out_proj(y)
-
-
-# Downsampling/upsampling
-
-_kernels = {
- 'linear':
- [1 / 8, 3 / 8, 3 / 8, 1 / 8],
- 'cubic':
- [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
- 0.43359375, 0.11328125, -0.03515625, -0.01171875],
- 'lanczos3':
- [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
- -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
- 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
- -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
-}
-_kernels['bilinear'] = _kernels['linear']
-_kernels['bicubic'] = _kernels['cubic']
-
-
-class Downsample2d(nn.Module):
- def __init__(self, kernel='linear', pad_mode='reflect'):
- super().__init__()
- self.pad_mode = pad_mode
- kernel_1d = torch.tensor([_kernels[kernel]])
- self.pad = kernel_1d.shape[1] // 2 - 1
- self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
-
- def forward(self, x):
- x = F.pad(x, (self.pad,) * 4, self.pad_mode)
- weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
- indices = torch.arange(x.shape[1], device=x.device)
- weight[indices, indices] = self.kernel.to(weight)
- return F.conv2d(x, weight, stride=2)
-
-
-class Upsample2d(nn.Module):
- def __init__(self, kernel='linear', pad_mode='reflect'):
- super().__init__()
- self.pad_mode = pad_mode
- kernel_1d = torch.tensor([_kernels[kernel]]) * 2
- self.pad = kernel_1d.shape[1] // 2 - 1
- self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
-
- def forward(self, x):
- x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
- weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
- indices = torch.arange(x.shape[1], device=x.device)
- weight[indices, indices] = self.kernel.to(weight)
- return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
-
-
-# Embeddings
-
-class FourierFeatures(nn.Module):
- def __init__(self, in_features, out_features, std=1.):
- super().__init__()
- assert out_features % 2 == 0
- self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
-
- def forward(self, input):
- f = 2 * math.pi * input @ self.weight.T
- return torch.cat([f.cos(), f.sin()], dim=-1)
-
-
-# U-Nets
-
-class UNet(ConditionedModule):
- def __init__(self, d_blocks, u_blocks, skip_stages=0):
- super().__init__()
- self.d_blocks = nn.ModuleList(d_blocks)
- self.u_blocks = nn.ModuleList(u_blocks)
- self.skip_stages = skip_stages
-
- def forward(self, input, cond):
- skips = []
- for block in self.d_blocks[self.skip_stages:]:
- input = block(input, cond)
- skips.append(input)
- for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
- input = block(input, cond, skip if i > 0 else None)
- return input
diff --git a/comfy/k_diffusion/models/__init__.py b/comfy/k_diffusion/models/__init__.py
deleted file mode 100644
index 82608ff1d..000000000
--- a/comfy/k_diffusion/models/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .image_v1 import ImageDenoiserModelV1
diff --git a/comfy/k_diffusion/models/image_v1.py b/comfy/k_diffusion/models/image_v1.py
deleted file mode 100644
index 9ffd5f2c4..000000000
--- a/comfy/k_diffusion/models/image_v1.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import math
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from .. import layers, utils
-
-
-def orthogonal_(module):
- nn.init.orthogonal_(module.weight)
- return module
-
-
-class ResConvBlock(layers.ConditionedResidualBlock):
- def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
- skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
- super().__init__(
- layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
- nn.GELU(),
- nn.Conv2d(c_in, c_mid, 3, padding=1),
- nn.Dropout2d(dropout_rate, inplace=True),
- layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
- nn.GELU(),
- nn.Conv2d(c_mid, c_out, 3, padding=1),
- nn.Dropout2d(dropout_rate, inplace=True),
- skip=skip)
-
-
-class DBlock(layers.ConditionedSequential):
- def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
- modules = [nn.Identity()]
- for i in range(n_layers):
- my_c_in = c_in if i == 0 else c_mid
- my_c_out = c_mid if i < n_layers - 1 else c_out
- modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
- if self_attn:
- norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
- modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
- if cross_attn:
- norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
- modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
- super().__init__(*modules)
- self.set_downsample(downsample)
-
- def set_downsample(self, downsample):
- self[0] = layers.Downsample2d() if downsample else nn.Identity()
- return self
-
-
-class UBlock(layers.ConditionedSequential):
- def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
- modules = []
- for i in range(n_layers):
- my_c_in = c_in if i == 0 else c_mid
- my_c_out = c_mid if i < n_layers - 1 else c_out
- modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
- if self_attn:
- norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
- modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
- if cross_attn:
- norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
- modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
- modules.append(nn.Identity())
- super().__init__(*modules)
- self.set_upsample(upsample)
-
- def forward(self, input, cond, skip=None):
- if skip is not None:
- input = torch.cat([input, skip], dim=1)
- return super().forward(input, cond)
-
- def set_upsample(self, upsample):
- self[-1] = layers.Upsample2d() if upsample else nn.Identity()
- return self
-
-
-class MappingNet(nn.Sequential):
- def __init__(self, feats_in, feats_out, n_layers=2):
- layers = []
- for i in range(n_layers):
- layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)))
- layers.append(nn.GELU())
- super().__init__(*layers)
-
-
-class ImageDenoiserModelV1(nn.Module):
- def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False):
- super().__init__()
- self.c_in = c_in
- self.channels = channels
- self.unet_cond_dim = unet_cond_dim
- self.patch_size = patch_size
- self.has_variance = has_variance
- self.timestep_embed = layers.FourierFeatures(1, feats_in)
- if mapping_cond_dim > 0:
- self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
- self.mapping = MappingNet(feats_in, feats_in)
- self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
- self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
- nn.init.zeros_(self.proj_out.weight)
- nn.init.zeros_(self.proj_out.bias)
- if cross_cond_dim == 0:
- cross_attn_depths = [False] * len(self_attn_depths)
- d_blocks, u_blocks = [], []
- for i in range(len(depths)):
- my_c_in = channels[max(0, i - 1)]
- d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
- for i in range(len(depths)):
- my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
- my_c_out = channels[max(0, i - 1)]
- u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
- self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)
-
- def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False):
- c_noise = sigma.log() / 4
- timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
- mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
- mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
- cond = {'cond': mapping_out}
- if unet_cond is not None:
- input = torch.cat([input, unet_cond], dim=1)
- if cross_cond is not None:
- cond['cross'] = cross_cond
- cond['cross_padding'] = cross_cond_padding
- if self.patch_size > 1:
- input = F.pixel_unshuffle(input, self.patch_size)
- input = self.proj_in(input)
- input = self.u_net(input, cond)
- input = self.proj_out(input)
- if self.has_variance:
- input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
- if self.patch_size > 1:
- input = F.pixel_shuffle(input, self.patch_size)
- if self.has_variance and return_variance:
- return input, logvar
- return input
-
- def set_skip_stages(self, skip_stages):
- self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
- self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
- nn.init.zeros_(self.proj_out.weight)
- nn.init.zeros_(self.proj_out.bias)
- self.u_net.skip_stages = skip_stages
- for i, block in enumerate(self.u_net.d_blocks):
- block.set_downsample(i > skip_stages)
- for i, block in enumerate(reversed(self.u_net.u_blocks)):
- block.set_upsample(i > skip_stages)
- return self
-
- def set_patch_size(self, patch_size):
- self.patch_size = patch_size
- self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1)
- self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
- nn.init.zeros_(self.proj_out.weight)
- nn.init.zeros_(self.proj_out.bias)
diff --git a/comfy/k_diffusion/utils.py b/comfy/k_diffusion/utils.py
index ce6014bea..a644df2f3 100644
--- a/comfy/k_diffusion/utils.py
+++ b/comfy/k_diffusion/utils.py
@@ -10,25 +10,6 @@ from PIL import Image
import torch
from torch import nn, optim
from torch.utils import data
-from torchvision.transforms import functional as TF
-
-
-def from_pil_image(x):
- """Converts from a PIL image to a tensor."""
- x = TF.to_tensor(x)
- if x.ndim == 2:
- x = x[..., None]
- return x * 2 - 1
-
-
-def to_pil_image(x):
- """Converts from a tensor to a PIL image."""
- if x.ndim == 4:
- assert x.shape[0] == 1
- x = x[0]
- if x.shape[0] == 1:
- x = x[0]
- return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
diff --git a/comfy/ldm/data/__init__.py b/comfy/ldm/data/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/comfy/ldm/data/util.py b/comfy/ldm/data/util.py
deleted file mode 100644
index 5b60ceb23..000000000
--- a/comfy/ldm/data/util.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import torch
-
-from ldm.modules.midas.api import load_midas_transform
-
-
-class AddMiDaS(object):
- def __init__(self, model_type):
- super().__init__()
- self.transform = load_midas_transform(model_type)
-
- def pt2np(self, x):
- x = ((x + 1.0) * .5).detach().cpu().numpy()
- return x
-
- def np2pt(self, x):
- x = torch.from_numpy(x) * 2 - 1.
- return x
-
- def __call__(self, sample):
- # sample['jpg'] is tensor hwc in [-1, 1] at this point
- x = self.pt2np(sample['jpg'])
- x = self.transform({"image": x})["image"]
- sample['midas_in'] = x
- return sample
\ No newline at end of file
diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py
index c279f2c18..d5649089a 100644
--- a/comfy/ldm/models/diffusion/ddim.py
+++ b/comfy/ldm/models/diffusion/ddim.py
@@ -284,7 +284,7 @@ class DDIMSampler(object):
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
- e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
else:
e_t = model_output
@@ -306,7 +306,7 @@ class DDIMSampler(object):
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
- pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+ pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py
deleted file mode 100644
index 0f484a7f1..000000000
--- a/comfy/ldm/models/diffusion/ddpm.py
+++ /dev/null
@@ -1,1875 +0,0 @@
-"""
-wild mixture of
-https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
-https://github.com/CompVis/taming-transformers
--- merci
-"""
-
-import torch
-import torch.nn as nn
-import numpy as np
-# import pytorch_lightning as pl
-from torch.optim.lr_scheduler import LambdaLR
-from einops import rearrange, repeat
-from contextlib import contextmanager, nullcontext
-from functools import partial
-import itertools
-from tqdm import tqdm
-from torchvision.utils import make_grid
-# from pytorch_lightning.utilities.distributed import rank_zero_only
-
-from comfy.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
-from comfy.ldm.modules.ema import LitEma
-from comfy.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ..autoencoder import IdentityFirstStage, AutoencoderKL
-from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
-from .ddim import DDIMSampler
-
-
-__conditioning_keys__ = {'concat': 'c_concat',
- 'crossattn': 'c_crossattn',
- 'adm': 'y'}
-
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-
-def uniform_on_device(r1, r2, shape, device):
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
-
-# class DDPM(pl.LightningModule):
-class DDPM(torch.nn.Module):
- # classic DDPM with Gaussian diffusion, in image space
- def __init__(self,
- unet_config,
- timesteps=1000,
- beta_schedule="linear",
- loss_type="l2",
- ckpt_path=None,
- ignore_keys=[],
- load_only_unet=False,
- monitor="val/loss",
- use_ema=True,
- first_stage_key="image",
- image_size=256,
- channels=3,
- log_every_t=100,
- clip_denoised=True,
- linear_start=1e-4,
- linear_end=2e-2,
- cosine_s=8e-3,
- given_betas=None,
- original_elbo_weight=0.,
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
- l_simple_weight=1.,
- conditioning_key=None,
- parameterization="eps", # all assuming fixed variance schedules
- scheduler_config=None,
- use_positional_encodings=False,
- learn_logvar=False,
- logvar_init=0.,
- make_it_fit=False,
- ucg_training=None,
- reset_ema=False,
- reset_num_ema_updates=False,
- ):
- super().__init__()
- assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
- self.parameterization = parameterization
- print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
- self.cond_stage_model = None
- self.clip_denoised = clip_denoised
- self.log_every_t = log_every_t
- self.first_stage_key = first_stage_key
- self.image_size = image_size # try conv?
- self.channels = channels
- self.use_positional_encodings = use_positional_encodings
- self.model = DiffusionWrapper(unet_config, conditioning_key)
- count_params(self.model, verbose=True)
- self.use_ema = use_ema
- if self.use_ema:
- self.model_ema = LitEma(self.model)
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
- self.use_scheduler = scheduler_config is not None
- if self.use_scheduler:
- self.scheduler_config = scheduler_config
-
- self.v_posterior = v_posterior
- self.original_elbo_weight = original_elbo_weight
- self.l_simple_weight = l_simple_weight
-
- if monitor is not None:
- self.monitor = monitor
- self.make_it_fit = make_it_fit
- if reset_ema: assert exists(ckpt_path)
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
- if reset_ema:
- assert self.use_ema
- print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
- self.model_ema = LitEma(self.model)
- if reset_num_ema_updates:
- print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
- assert self.use_ema
- self.model_ema.reset_num_updates()
-
- self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
- linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
-
- self.loss_type = loss_type
-
- self.learn_logvar = learn_logvar
- self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
- if self.learn_logvar:
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-
- self.ucg_training = ucg_training or dict()
- if self.ucg_training:
- self.ucg_prng = np.random.RandomState()
-
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
- if exists(given_betas):
- betas = given_betas
- else:
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
- cosine_s=cosine_s)
- alphas = 1. - betas
- alphas_cumprod = np.cumprod(alphas, axis=0)
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
-
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.linear_start = linear_start
- self.linear_end = linear_end
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
-
- to_torch = partial(torch.tensor, dtype=torch.float32)
-
- self.register_buffer('betas', to_torch(betas))
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
-
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
- 1. - alphas_cumprod) + self.v_posterior * betas
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
- self.register_buffer('posterior_mean_coef1', to_torch(
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
- self.register_buffer('posterior_mean_coef2', to_torch(
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
-
- if self.parameterization == "eps":
- lvlb_weights = self.betas ** 2 / (
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
- elif self.parameterization == "x0":
- lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
- elif self.parameterization == "v":
- lvlb_weights = torch.ones_like(self.betas ** 2 / (
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
- else:
- raise NotImplementedError("mu not supported")
- lvlb_weights[0] = lvlb_weights[1]
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
- assert not torch.isnan(self.lvlb_weights).all()
-
- @contextmanager
- def ema_scope(self, context=None):
- if self.use_ema:
- self.model_ema.store(self.model.parameters())
- self.model_ema.copy_to(self.model)
- if context is not None:
- print(f"{context}: Switched to EMA weights")
- try:
- yield None
- finally:
- if self.use_ema:
- self.model_ema.restore(self.model.parameters())
- if context is not None:
- print(f"{context}: Restored training weights")
-
- @torch.no_grad()
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
- sd = torch.load(path, map_location="cpu")
- if "state_dict" in list(sd.keys()):
- sd = sd["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- if self.make_it_fit:
- n_params = len([name for name, _ in
- itertools.chain(self.named_parameters(),
- self.named_buffers())])
- for name, param in tqdm(
- itertools.chain(self.named_parameters(),
- self.named_buffers()),
- desc="Fitting old weights to new weights",
- total=n_params
- ):
- if not name in sd:
- continue
- old_shape = sd[name].shape
- new_shape = param.shape
- assert len(old_shape) == len(new_shape)
- if len(new_shape) > 2:
- # we only modify first two axes
- assert new_shape[2:] == old_shape[2:]
- # assumes first axis corresponds to output dim
- if not new_shape == old_shape:
- new_param = param.clone()
- old_param = sd[name]
- if len(new_shape) == 1:
- for i in range(new_param.shape[0]):
- new_param[i] = old_param[i % old_shape[0]]
- elif len(new_shape) >= 2:
- for i in range(new_param.shape[0]):
- for j in range(new_param.shape[1]):
- new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
-
- n_used_old = torch.ones(old_shape[1])
- for j in range(new_param.shape[1]):
- n_used_old[j % old_shape[1]] += 1
- n_used_new = torch.zeros(new_shape[1])
- for j in range(new_param.shape[1]):
- n_used_new[j] = n_used_old[j % old_shape[1]]
-
- n_used_new = n_used_new[None, :]
- while len(n_used_new.shape) < len(new_shape):
- n_used_new = n_used_new.unsqueeze(-1)
- new_param /= n_used_new
-
- sd[name] = new_param
-
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
- sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
- print(f"Missing Keys:\n {missing}")
- if len(unexpected) > 0:
- print(f"\nUnexpected Keys:\n {unexpected}")
-
- def q_mean_variance(self, x_start, t):
- """
- Get the distribution q(x_t | x_0).
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
- """
- mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
- log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
- return mean, variance, log_variance
-
- def predict_start_from_noise(self, x_t, t, noise):
- return (
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
- )
-
- def predict_start_from_z_and_v(self, x_t, t, v):
- # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
- # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
- )
-
- def predict_eps_from_z_and_v(self, x_t, t, v):
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
- )
-
- def q_posterior(self, x_start, x_t, t):
- posterior_mean = (
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
- )
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
- posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
- def p_mean_variance(self, x, t, clip_denoised: bool):
- model_out = self.model(x, t)
- if self.parameterization == "eps":
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
- elif self.parameterization == "x0":
- x_recon = model_out
- if clip_denoised:
- x_recon.clamp_(-1., 1.)
-
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
- b, *_, device = *x.shape, x.device
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
- noise = noise_like(x.shape, device, repeat_noise)
- # no noise when t == 0
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-
- @torch.no_grad()
- def p_sample_loop(self, shape, return_intermediates=False):
- device = self.betas.device
- b = shape[0]
- img = torch.randn(shape, device=device)
- intermediates = [img]
- for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
- clip_denoised=self.clip_denoised)
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
- intermediates.append(img)
- if return_intermediates:
- return img, intermediates
- return img
-
- @torch.no_grad()
- def sample(self, batch_size=16, return_intermediates=False):
- image_size = self.image_size
- channels = self.channels
- return self.p_sample_loop((batch_size, channels, image_size, image_size),
- return_intermediates=return_intermediates)
-
- def q_sample(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
-
- def get_v(self, x, noise, t):
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
- )
-
- def get_loss(self, pred, target, mean=True):
- if self.loss_type == 'l1':
- loss = (target - pred).abs()
- if mean:
- loss = loss.mean()
- elif self.loss_type == 'l2':
- if mean:
- loss = torch.nn.functional.mse_loss(target, pred)
- else:
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
- else:
- raise NotImplementedError("unknown loss type '{loss_type}'")
-
- return loss
-
- def p_losses(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- model_out = self.model(x_noisy, t)
-
- loss_dict = {}
- if self.parameterization == "eps":
- target = noise
- elif self.parameterization == "x0":
- target = x_start
- elif self.parameterization == "v":
- target = self.get_v(x_start, noise, t)
- else:
- raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
-
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
-
- log_prefix = 'train' if self.training else 'val'
-
- loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
- loss_simple = loss.mean() * self.l_simple_weight
-
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
- loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
-
- loss = loss_simple + self.original_elbo_weight * loss_vlb
-
- loss_dict.update({f'{log_prefix}/loss': loss})
-
- return loss, loss_dict
-
- def forward(self, x, *args, **kwargs):
- # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
- # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
- return self.p_losses(x, t, *args, **kwargs)
-
- def get_input(self, batch, k):
- x = batch[k]
- if len(x.shape) == 3:
- x = x[..., None]
- x = rearrange(x, 'b h w c -> b c h w')
- x = x.to(memory_format=torch.contiguous_format).float()
- return x
-
- def shared_step(self, batch):
- x = self.get_input(batch, self.first_stage_key)
- loss, loss_dict = self(x)
- return loss, loss_dict
-
- def training_step(self, batch, batch_idx):
- for k in self.ucg_training:
- p = self.ucg_training[k]["p"]
- val = self.ucg_training[k]["val"]
- if val is None:
- val = ""
- for i in range(len(batch[k])):
- if self.ucg_prng.choice(2, p=[1 - p, p]):
- batch[k][i] = val
-
- loss, loss_dict = self.shared_step(batch)
-
- self.log_dict(loss_dict, prog_bar=True,
- logger=True, on_step=True, on_epoch=True)
-
- self.log("global_step", self.global_step,
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
-
- if self.use_scheduler:
- lr = self.optimizers().param_groups[0]['lr']
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
-
- return loss
-
- @torch.no_grad()
- def validation_step(self, batch, batch_idx):
- _, loss_dict_no_ema = self.shared_step(batch)
- with self.ema_scope():
- _, loss_dict_ema = self.shared_step(batch)
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
- self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
- self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
-
- def on_train_batch_end(self, *args, **kwargs):
- if self.use_ema:
- self.model_ema(self.model)
-
- def _get_rows_from_list(self, samples):
- n_imgs_per_row = len(samples)
- denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
- return denoise_grid
-
- @torch.no_grad()
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
- log = dict()
- x = self.get_input(batch, self.first_stage_key)
- N = min(x.shape[0], N)
- n_row = min(x.shape[0], n_row)
- x = x.to(self.device)[:N]
- log["inputs"] = x
-
- # get diffusion row
- diffusion_row = list()
- x_start = x[:n_row]
-
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(x_start)
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- diffusion_row.append(x_noisy)
-
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
-
- if sample:
- # get denoise row
- with self.ema_scope("Plotting"):
- samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
-
- log["samples"] = samples
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
-
- if return_keys:
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
- return log
- else:
- return {key: log[key] for key in return_keys}
- return log
-
- def configure_optimizers(self):
- lr = self.learning_rate
- params = list(self.model.parameters())
- if self.learn_logvar:
- params = params + [self.logvar]
- opt = torch.optim.AdamW(params, lr=lr)
- return opt
-
-
-class LatentDiffusion(DDPM):
- """main class"""
-
- def __init__(self,
- first_stage_config={},
- cond_stage_config={},
- num_timesteps_cond=None,
- cond_stage_key="image",
- cond_stage_trainable=False,
- concat_mode=True,
- cond_stage_forward=None,
- conditioning_key=None,
- scale_factor=1.0,
- scale_by_std=False,
- force_null_conditioning=False,
- *args, **kwargs):
- self.force_null_conditioning = force_null_conditioning
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
- self.scale_by_std = scale_by_std
- assert self.num_timesteps_cond <= kwargs['timesteps']
- # for backwards compatibility after implementation of DiffusionWrapper
- if conditioning_key is None:
- conditioning_key = 'concat' if concat_mode else 'crossattn'
- if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
- conditioning_key = None
- ckpt_path = kwargs.pop("ckpt_path", None)
- reset_ema = kwargs.pop("reset_ema", False)
- reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
- ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
- self.concat_mode = concat_mode
- self.cond_stage_trainable = cond_stage_trainable
- self.cond_stage_key = cond_stage_key
- try:
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
- except:
- self.num_downs = 0
- if not scale_by_std:
- self.scale_factor = scale_factor
- else:
- self.register_buffer('scale_factor', torch.tensor(scale_factor))
-
- # self.instantiate_first_stage(first_stage_config)
- # self.instantiate_cond_stage(cond_stage_config)
-
- self.cond_stage_forward = cond_stage_forward
- self.clip_denoised = False
- self.bbox_tokenizer = None
-
- self.restarted_from_ckpt = False
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys)
- self.restarted_from_ckpt = True
- if reset_ema:
- assert self.use_ema
- print(
- f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
- self.model_ema = LitEma(self.model)
- if reset_num_ema_updates:
- print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
- assert self.use_ema
- self.model_ema.reset_num_updates()
-
- def make_cond_schedule(self, ):
- self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
- ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
- self.cond_ids[:self.num_timesteps_cond] = ids
-
- # @rank_zero_only
- @torch.no_grad()
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
- # only for very first batch
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
- # set rescale weight to 1./std of encodings
- print("### USING STD-RESCALING ###")
- x = super().get_input(batch, self.first_stage_key)
- x = x.to(self.device)
- encoder_posterior = self.encode_first_stage(x)
- z = self.get_first_stage_encoding(encoder_posterior).detach()
- del self.scale_factor
- self.register_buffer('scale_factor', 1. / z.flatten().std())
- print(f"setting self.scale_factor to {self.scale_factor}")
- print("### USING STD-RESCALING ###")
-
- def register_schedule(self,
- given_betas=None, beta_schedule="linear", timesteps=1000,
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
-
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
- if self.shorten_cond_schedule:
- self.make_cond_schedule()
-
- def instantiate_first_stage(self, config):
- model = instantiate_from_config(config)
- self.first_stage_model = model.eval()
- self.first_stage_model.train = disabled_train
- for param in self.first_stage_model.parameters():
- param.requires_grad = False
-
- def instantiate_cond_stage(self, config):
- if not self.cond_stage_trainable:
- if config == "__is_first_stage__":
- print("Using first stage also as cond stage.")
- self.cond_stage_model = self.first_stage_model
- elif config == "__is_unconditional__":
- print(f"Training {self.__class__.__name__} as an unconditional model.")
- self.cond_stage_model = None
- # self.be_unconditional = True
- else:
- model = instantiate_from_config(config)
- self.cond_stage_model = model.eval()
- self.cond_stage_model.train = disabled_train
- for param in self.cond_stage_model.parameters():
- param.requires_grad = False
- else:
- assert config != '__is_first_stage__'
- assert config != '__is_unconditional__'
- model = instantiate_from_config(config)
- self.cond_stage_model = model
-
- def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
- denoise_row = []
- for zd in tqdm(samples, desc=desc):
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
- force_not_quantize=force_no_decoder_quantization))
- n_imgs_per_row = len(denoise_row)
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
- denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
- return denoise_grid
-
- def get_first_stage_encoding(self, encoder_posterior):
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
- z = encoder_posterior.sample()
- elif isinstance(encoder_posterior, torch.Tensor):
- z = encoder_posterior
- else:
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
- return self.scale_factor * z
-
- def get_learned_conditioning(self, c):
- if self.cond_stage_forward is None:
- if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
- c = self.cond_stage_model.encode(c)
- if isinstance(c, DiagonalGaussianDistribution):
- c = c.mode()
- else:
- c = self.cond_stage_model(c)
- else:
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
- return c
-
- def meshgrid(self, h, w):
- y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
- x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
-
- arr = torch.cat([y, x], dim=-1)
- return arr
-
- def delta_border(self, h, w):
- """
- :param h: height
- :param w: width
- :return: normalized distance to image border,
- wtith min distance = 0 at border and max dist = 0.5 at image center
- """
- lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
- arr = self.meshgrid(h, w) / lower_right_corner
- dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
- dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
- edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
- return edge_dist
-
- def get_weighting(self, h, w, Ly, Lx, device):
- weighting = self.delta_border(h, w)
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
- self.split_input_params["clip_max_weight"], )
- weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
-
- if self.split_input_params["tie_braker"]:
- L_weighting = self.delta_border(Ly, Lx)
- L_weighting = torch.clip(L_weighting,
- self.split_input_params["clip_min_tie_weight"],
- self.split_input_params["clip_max_tie_weight"])
-
- L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
- weighting = weighting * L_weighting
- return weighting
-
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
- """
- :param x: img of size (bs, c, h, w)
- :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
- """
- bs, nc, h, w = x.shape
-
- # number of crops in image
- Ly = (h - kernel_size[0]) // stride[0] + 1
- Lx = (w - kernel_size[1]) // stride[1] + 1
-
- if uf == 1 and df == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
-
- weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
-
- elif uf > 1 and df == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
- dilation=1, padding=0,
- stride=(stride[0] * uf, stride[1] * uf))
- fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
-
- weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
-
- elif df > 1 and uf == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
- dilation=1, padding=0,
- stride=(stride[0] // df, stride[1] // df))
- fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
-
- weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
-
- else:
- raise NotImplementedError
-
- return fold, unfold, normalization, weighting
-
- @torch.no_grad()
- def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
- cond_key=None, return_original_cond=False, bs=None, return_x=False):
- x = super().get_input(batch, k)
- if bs is not None:
- x = x[:bs]
- x = x.to(self.device)
- encoder_posterior = self.encode_first_stage(x)
- z = self.get_first_stage_encoding(encoder_posterior).detach()
-
- if self.model.conditioning_key is not None and not self.force_null_conditioning:
- if cond_key is None:
- cond_key = self.cond_stage_key
- if cond_key != self.first_stage_key:
- if cond_key in ['caption', 'coordinates_bbox', "txt"]:
- xc = batch[cond_key]
- elif cond_key in ['class_label', 'cls']:
- xc = batch
- else:
- xc = super().get_input(batch, cond_key).to(self.device)
- else:
- xc = x
- if not self.cond_stage_trainable or force_c_encode:
- if isinstance(xc, dict) or isinstance(xc, list):
- c = self.get_learned_conditioning(xc)
- else:
- c = self.get_learned_conditioning(xc.to(self.device))
- else:
- c = xc
- if bs is not None:
- c = c[:bs]
-
- if self.use_positional_encodings:
- pos_x, pos_y = self.compute_latent_shifts(batch)
- ckey = __conditioning_keys__[self.model.conditioning_key]
- c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
-
- else:
- c = None
- xc = None
- if self.use_positional_encodings:
- pos_x, pos_y = self.compute_latent_shifts(batch)
- c = {'pos_x': pos_x, 'pos_y': pos_y}
- out = [z, c]
- if return_first_stage_outputs:
- xrec = self.decode_first_stage(z)
- out.extend([x, xrec])
- if return_x:
- out.extend([x])
- if return_original_cond:
- out.append(xc)
- return out
-
- @torch.no_grad()
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
- if predict_cids:
- if z.dim() == 4:
- z = torch.argmax(z.exp(), dim=1).long()
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
-
- z = 1. / self.scale_factor * z
- return self.first_stage_model.decode(z)
-
- @torch.no_grad()
- def encode_first_stage(self, x):
- return self.first_stage_model.encode(x)
-
- def shared_step(self, batch, **kwargs):
- x, c = self.get_input(batch, self.first_stage_key)
- loss = self(x, c)
- return loss
-
- def forward(self, x, c, *args, **kwargs):
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
- if self.model.conditioning_key is not None:
- assert c is not None
- if self.cond_stage_trainable:
- c = self.get_learned_conditioning(c)
- if self.shorten_cond_schedule: # TODO: drop this option
- tc = self.cond_ids[t].to(self.device)
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
- return self.p_losses(x, c, t, *args, **kwargs)
-
- def apply_model(self, x_noisy, t, cond, return_ids=False):
- if isinstance(cond, dict):
- # hybrid case, cond is expected to be a dict
- pass
- else:
- if not isinstance(cond, list):
- cond = [cond]
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
- cond = {key: cond}
-
- x_recon = self.model(x_noisy, t, **cond)
-
- if isinstance(x_recon, tuple) and not return_ids:
- return x_recon[0]
- else:
- return x_recon
-
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
- def _prior_bpd(self, x_start):
- """
- Get the prior KL term for the variational lower-bound, measured in
- bits-per-dim.
- This term can't be optimized, as it only depends on the encoder.
- :param x_start: the [N x C x ...] tensor of inputs.
- :return: a batch of [N] KL values (in bits), one per batch element.
- """
- batch_size = x_start.shape[0]
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
- return mean_flat(kl_prior) / np.log(2.0)
-
- def p_losses(self, x_start, cond, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- model_output = self.apply_model(x_noisy, t, cond)
-
- loss_dict = {}
- prefix = 'train' if self.training else 'val'
-
- if self.parameterization == "x0":
- target = x_start
- elif self.parameterization == "eps":
- target = noise
- elif self.parameterization == "v":
- target = self.get_v(x_start, noise, t)
- else:
- raise NotImplementedError()
-
- loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
- loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
-
- logvar_t = self.logvar[t].to(self.device)
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
- # loss = loss_simple / torch.exp(self.logvar) + self.logvar
- if self.learn_logvar:
- loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
- loss_dict.update({'logvar': self.logvar.data.mean()})
-
- loss = self.l_simple_weight * loss.mean()
-
- loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
- loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
- loss += (self.original_elbo_weight * loss_vlb)
- loss_dict.update({f'{prefix}/loss': loss})
-
- return loss, loss_dict
-
- def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
- return_x0=False, score_corrector=None, corrector_kwargs=None):
- t_in = t
- model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
-
- if score_corrector is not None:
- assert self.parameterization == "eps"
- model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
-
- if return_codebook_ids:
- model_out, logits = model_out
-
- if self.parameterization == "eps":
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
- elif self.parameterization == "x0":
- x_recon = model_out
- else:
- raise NotImplementedError()
-
- if clip_denoised:
- x_recon.clamp_(-1., 1.)
- if quantize_denoised:
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
- if return_codebook_ids:
- return model_mean, posterior_variance, posterior_log_variance, logits
- elif return_x0:
- return model_mean, posterior_variance, posterior_log_variance, x_recon
- else:
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
- b, *_, device = *x.shape, x.device
- outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
- return_codebook_ids=return_codebook_ids,
- quantize_denoised=quantize_denoised,
- return_x0=return_x0,
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
- if return_codebook_ids:
- raise DeprecationWarning("Support dropped.")
- model_mean, _, model_log_variance, logits = outputs
- elif return_x0:
- model_mean, _, model_log_variance, x0 = outputs
- else:
- model_mean, _, model_log_variance = outputs
-
- noise = noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- # no noise when t == 0
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-
- if return_codebook_ids:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
- if return_x0:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
- else:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-
- @torch.no_grad()
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
- log_every_t=None):
- if not log_every_t:
- log_every_t = self.log_every_t
- timesteps = self.num_timesteps
- if batch_size is not None:
- b = batch_size if batch_size is not None else shape[0]
- shape = [batch_size] + list(shape)
- else:
- b = batch_size = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=self.device)
- else:
- img = x_T
- intermediates = []
- if cond is not None:
- if isinstance(cond, dict):
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
- else:
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
-
- if start_T is not None:
- timesteps = min(timesteps, start_T)
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
- total=timesteps) if verbose else reversed(
- range(0, timesteps))
- if type(temperature) == float:
- temperature = [temperature] * timesteps
-
- for i in iterator:
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
- if self.shorten_cond_schedule:
- assert self.model.conditioning_key != 'hybrid'
- tc = self.cond_ids[ts].to(cond.device)
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
-
- img, x0_partial = self.p_sample(img, cond, ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised, return_x0=True,
- temperature=temperature[i], noise_dropout=noise_dropout,
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
- if mask is not None:
- assert x0 is not None
- img_orig = self.q_sample(x0, ts)
- img = img_orig * mask + (1. - mask) * img
-
- if i % log_every_t == 0 or i == timesteps - 1:
- intermediates.append(x0_partial)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
- return img, intermediates
-
- @torch.no_grad()
- def p_sample_loop(self, cond, shape, return_intermediates=False,
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, img_callback=None, start_T=None,
- log_every_t=None):
-
- if not log_every_t:
- log_every_t = self.log_every_t
- device = self.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device)
- else:
- img = x_T
-
- intermediates = [img]
- if timesteps is None:
- timesteps = self.num_timesteps
-
- if start_T is not None:
- timesteps = min(timesteps, start_T)
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
- range(0, timesteps))
-
- if mask is not None:
- assert x0 is not None
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
-
- for i in iterator:
- ts = torch.full((b,), i, device=device, dtype=torch.long)
- if self.shorten_cond_schedule:
- assert self.model.conditioning_key != 'hybrid'
- tc = self.cond_ids[ts].to(cond.device)
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
-
- img = self.p_sample(img, cond, ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised)
- if mask is not None:
- img_orig = self.q_sample(x0, ts)
- img = img_orig * mask + (1. - mask) * img
-
- if i % log_every_t == 0 or i == timesteps - 1:
- intermediates.append(img)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
-
- if return_intermediates:
- return img, intermediates
- return img
-
- @torch.no_grad()
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
- verbose=True, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, shape=None, **kwargs):
- if shape is None:
- shape = (batch_size, self.channels, self.image_size, self.image_size)
- if cond is not None:
- if isinstance(cond, dict):
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
- else:
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
- return self.p_sample_loop(cond,
- shape,
- return_intermediates=return_intermediates, x_T=x_T,
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
- mask=mask, x0=x0)
-
- @torch.no_grad()
- def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
- if ddim:
- ddim_sampler = DDIMSampler(self)
- shape = (self.channels, self.image_size, self.image_size)
- samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
- shape, cond, verbose=False, **kwargs)
-
- else:
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
- return_intermediates=True, **kwargs)
-
- return samples, intermediates
-
- @torch.no_grad()
- def get_unconditional_conditioning(self, batch_size, null_label=None):
- if null_label is not None:
- xc = null_label
- # if isinstance(xc, ListConfig):
- # xc = list(xc)
- if isinstance(xc, dict) or isinstance(xc, list):
- c = self.get_learned_conditioning(xc)
- else:
- if hasattr(xc, "to"):
- xc = xc.to(self.device)
- c = self.get_learned_conditioning(xc)
- else:
- if self.cond_stage_key in ["class_label", "cls"]:
- xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
- return self.get_learned_conditioning(xc)
- else:
- raise NotImplementedError("todo")
- if isinstance(c, list): # in case the encoder gives us a list
- for i in range(len(c)):
- c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
- else:
- c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
- return c
-
- @torch.no_grad()
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
- plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
- use_ema_scope=True,
- **kwargs):
- ema_scope = self.ema_scope if use_ema_scope else nullcontext
- use_ddim = ddim_steps is not None
-
- log = dict()
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
- return_first_stage_outputs=True,
- force_c_encode=True,
- return_original_cond=True,
- bs=N)
- N = min(x.shape[0], N)
- n_row = min(x.shape[0], n_row)
- log["inputs"] = x
- log["reconstruction"] = xrec
- if self.model.conditioning_key is not None:
- if hasattr(self.cond_stage_model, "decode"):
- xc = self.cond_stage_model.decode(c)
- log["conditioning"] = xc
- elif self.cond_stage_key in ["caption", "txt"]:
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
- log["conditioning"] = xc
- elif self.cond_stage_key in ['class_label', "cls"]:
- try:
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
- log['conditioning'] = xc
- except KeyError:
- # probably no "human_label" in batch
- pass
- elif isimage(xc):
- log["conditioning"] = xc
- if ismap(xc):
- log["original_conditioning"] = self.to_rgb(xc)
-
- if plot_diffusion_rows:
- # get diffusion row
- diffusion_row = list()
- z_start = z[:n_row]
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(z_start)
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
- diffusion_row.append(self.decode_first_stage(z_noisy))
-
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
- log["diffusion_row"] = diffusion_grid
-
- if sample:
- # get denoise row
- with ema_scope("Sampling"):
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta)
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
- x_samples = self.decode_first_stage(samples)
- log["samples"] = x_samples
- if plot_denoise_rows:
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
- log["denoise_row"] = denoise_grid
-
- if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
- self.first_stage_model, IdentityFirstStage):
- # also display when quantizing x0 while sampling
- with ema_scope("Plotting Quantized Denoised"):
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta,
- quantize_denoised=True)
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
- # quantize_denoised=True)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_x0_quantized"] = x_samples
-
- if unconditional_guidance_scale > 1.0:
- uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
- if self.model.conditioning_key == "crossattn-adm":
- uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
- with ema_scope("Sampling with classifier-free guidance"):
- samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=uc,
- )
- x_samples_cfg = self.decode_first_stage(samples_cfg)
- log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
-
- if inpaint:
- # make a simple center square
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
- mask = torch.ones(N, h, w).to(self.device)
- # zeros will be filled in
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
- mask = mask[:, None, ...]
- with ema_scope("Plotting Inpaint"):
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_inpainting"] = x_samples
- log["mask"] = mask
-
- # outpaint
- mask = 1. - mask
- with ema_scope("Plotting Outpaint"):
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_outpainting"] = x_samples
-
- if plot_progressive_rows:
- with ema_scope("Plotting Progressives"):
- img, progressives = self.progressive_denoising(c,
- shape=(self.channels, self.image_size, self.image_size),
- batch_size=N)
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
- log["progressive_row"] = prog_row
-
- if return_keys:
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
- return log
- else:
- return {key: log[key] for key in return_keys}
- return log
-
- def configure_optimizers(self):
- lr = self.learning_rate
- params = list(self.model.parameters())
- if self.cond_stage_trainable:
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
- params = params + list(self.cond_stage_model.parameters())
- if self.learn_logvar:
- print('Diffusion model optimizing logvar')
- params.append(self.logvar)
- opt = torch.optim.AdamW(params, lr=lr)
- if self.use_scheduler:
- assert 'target' in self.scheduler_config
- scheduler = instantiate_from_config(self.scheduler_config)
-
- print("Setting up LambdaLR scheduler...")
- scheduler = [
- {
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- }]
- return [opt], scheduler
- return opt
-
- @torch.no_grad()
- def to_rgb(self, x):
- x = x.float()
- if not hasattr(self, "colorize"):
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
- x = nn.functional.conv2d(x, weight=self.colorize)
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
- return x
-
-
-# class DiffusionWrapper(pl.LightningModule):
-class DiffusionWrapper(torch.nn.Module):
- def __init__(self, diff_model_config, conditioning_key):
- super().__init__()
- self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
- self.diffusion_model = instantiate_from_config(diff_model_config)
- self.conditioning_key = conditioning_key
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
-
- def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
- if self.conditioning_key is None:
- out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
- elif self.conditioning_key == 'concat':
- xc = torch.cat([x] + c_concat, dim=1)
- out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options)
- elif self.conditioning_key == 'crossattn':
- if not self.sequential_cross_attn:
- cc = torch.cat(c_crossattn, 1)
- else:
- cc = c_crossattn
- if hasattr(self, "scripted_diffusion_model"):
- # TorchScript changes names of the arguments
- # with argument cc defined as context=cc scripted model will produce
- # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
- out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
- else:
- out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
- elif self.conditioning_key == 'hybrid':
- xc = torch.cat([x] + c_concat, dim=1)
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options)
- elif self.conditioning_key == 'hybrid-adm':
- assert c_adm is not None
- xc = torch.cat([x] + c_concat, dim=1)
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
- elif self.conditioning_key == 'crossattn-adm':
- assert c_adm is not None
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
- elif self.conditioning_key == 'adm':
- cc = c_crossattn[0]
- out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options)
- else:
- raise NotImplementedError()
-
- return out
-
-
-class LatentUpscaleDiffusion(LatentDiffusion):
- def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
- super().__init__(*args, **kwargs)
- # assumes that neither the cond_stage nor the low_scale_model contain trainable params
- assert not self.cond_stage_trainable
- self.instantiate_low_stage(low_scale_config)
- self.low_scale_key = low_scale_key
- self.noise_level_key = noise_level_key
-
- def instantiate_low_stage(self, config):
- model = instantiate_from_config(config)
- self.low_scale_model = model.eval()
- self.low_scale_model.train = disabled_train
- for param in self.low_scale_model.parameters():
- param.requires_grad = False
-
- @torch.no_grad()
- def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
- if not log_mode:
- z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
- else:
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
- force_c_encode=True, return_original_cond=True, bs=bs)
- x_low = batch[self.low_scale_key][:bs]
- x_low = rearrange(x_low, 'b h w c -> b c h w')
- x_low = x_low.to(memory_format=torch.contiguous_format).float()
- zx, noise_level = self.low_scale_model(x_low)
- if self.noise_level_key is not None:
- # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
- raise NotImplementedError('TODO')
-
- all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
- if log_mode:
- # TODO: maybe disable if too expensive
- x_low_rec = self.low_scale_model.decode(zx)
- return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
- return z, all_conds
-
- @torch.no_grad()
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
- plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
- unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
- **kwargs):
- ema_scope = self.ema_scope if use_ema_scope else nullcontext
- use_ddim = ddim_steps is not None
-
- log = dict()
- z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
- log_mode=True)
- N = min(x.shape[0], N)
- n_row = min(x.shape[0], n_row)
- log["inputs"] = x
- log["reconstruction"] = xrec
- log["x_lr"] = x_low
- log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
- if self.model.conditioning_key is not None:
- if hasattr(self.cond_stage_model, "decode"):
- xc = self.cond_stage_model.decode(c)
- log["conditioning"] = xc
- elif self.cond_stage_key in ["caption", "txt"]:
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
- log["conditioning"] = xc
- elif self.cond_stage_key in ['class_label', 'cls']:
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
- log['conditioning'] = xc
- elif isimage(xc):
- log["conditioning"] = xc
- if ismap(xc):
- log["original_conditioning"] = self.to_rgb(xc)
-
- if plot_diffusion_rows:
- # get diffusion row
- diffusion_row = list()
- z_start = z[:n_row]
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(z_start)
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
- diffusion_row.append(self.decode_first_stage(z_noisy))
-
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
- log["diffusion_row"] = diffusion_grid
-
- if sample:
- # get denoise row
- with ema_scope("Sampling"):
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta)
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
- x_samples = self.decode_first_stage(samples)
- log["samples"] = x_samples
- if plot_denoise_rows:
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
- log["denoise_row"] = denoise_grid
-
- if unconditional_guidance_scale > 1.0:
- uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
- # TODO explore better "unconditional" choices for the other keys
- # maybe guide away from empty text label and highest noise level and maximally degraded zx?
- uc = dict()
- for k in c:
- if k == "c_crossattn":
- assert isinstance(c[k], list) and len(c[k]) == 1
- uc[k] = [uc_tmp]
- elif k == "c_adm": # todo: only run with text-based guidance?
- assert isinstance(c[k], torch.Tensor)
- #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
- uc[k] = c[k]
- elif isinstance(c[k], list):
- uc[k] = [c[k][i] for i in range(len(c[k]))]
- else:
- uc[k] = c[k]
-
- with ema_scope("Sampling with classifier-free guidance"):
- samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=uc,
- )
- x_samples_cfg = self.decode_first_stage(samples_cfg)
- log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
-
- if plot_progressive_rows:
- with ema_scope("Plotting Progressives"):
- img, progressives = self.progressive_denoising(c,
- shape=(self.channels, self.image_size, self.image_size),
- batch_size=N)
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
- log["progressive_row"] = prog_row
-
- return log
-
-
-class LatentFinetuneDiffusion(LatentDiffusion):
- """
- Basis for different finetunas, such as inpainting or depth2image
- To disable finetuning mode, set finetune_keys to None
- """
-
- def __init__(self,
- concat_keys: tuple,
- finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
- "model_ema.diffusion_modelinput_blocks00weight"
- ),
- keep_finetune_dims=4,
- # if model was trained without concat mode before and we would like to keep these channels
- c_concat_log_start=None, # to log reconstruction of c_concat codes
- c_concat_log_end=None,
- *args, **kwargs
- ):
- ckpt_path = kwargs.pop("ckpt_path", None)
- ignore_keys = kwargs.pop("ignore_keys", list())
- super().__init__(*args, **kwargs)
- self.finetune_keys = finetune_keys
- self.concat_keys = concat_keys
- self.keep_dims = keep_finetune_dims
- self.c_concat_log_start = c_concat_log_start
- self.c_concat_log_end = c_concat_log_end
- if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
- if exists(ckpt_path):
- self.init_from_ckpt(ckpt_path, ignore_keys)
-
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
- sd = torch.load(path, map_location="cpu")
- if "state_dict" in list(sd.keys()):
- sd = sd["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
-
- # make it explicit, finetune by including extra input channels
- if exists(self.finetune_keys) and k in self.finetune_keys:
- new_entry = None
- for name, param in self.named_parameters():
- if name in self.finetune_keys:
- print(
- f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
- new_entry = torch.zeros_like(param) # zero init
- assert exists(new_entry), 'did not find matching parameter to modify'
- new_entry[:, :self.keep_dims, ...] = sd[k]
- sd[k] = new_entry
-
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
- sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
- print(f"Missing Keys: {missing}")
- if len(unexpected) > 0:
- print(f"Unexpected Keys: {unexpected}")
-
- @torch.no_grad()
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
- plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
- use_ema_scope=True,
- **kwargs):
- ema_scope = self.ema_scope if use_ema_scope else nullcontext
- use_ddim = ddim_steps is not None
-
- log = dict()
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
- c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
- N = min(x.shape[0], N)
- n_row = min(x.shape[0], n_row)
- log["inputs"] = x
- log["reconstruction"] = xrec
- if self.model.conditioning_key is not None:
- if hasattr(self.cond_stage_model, "decode"):
- xc = self.cond_stage_model.decode(c)
- log["conditioning"] = xc
- elif self.cond_stage_key in ["caption", "txt"]:
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
- log["conditioning"] = xc
- elif self.cond_stage_key in ['class_label', 'cls']:
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
- log['conditioning'] = xc
- elif isimage(xc):
- log["conditioning"] = xc
- if ismap(xc):
- log["original_conditioning"] = self.to_rgb(xc)
-
- if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
- log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
-
- if plot_diffusion_rows:
- # get diffusion row
- diffusion_row = list()
- z_start = z[:n_row]
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(z_start)
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
- diffusion_row.append(self.decode_first_stage(z_noisy))
-
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
- log["diffusion_row"] = diffusion_grid
-
- if sample:
- # get denoise row
- with ema_scope("Sampling"):
- samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
- batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta)
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
- x_samples = self.decode_first_stage(samples)
- log["samples"] = x_samples
- if plot_denoise_rows:
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
- log["denoise_row"] = denoise_grid
-
- if unconditional_guidance_scale > 1.0:
- uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
- uc_cat = c_cat
- uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
- with ema_scope("Sampling with classifier-free guidance"):
- samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
- batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=uc_full,
- )
- x_samples_cfg = self.decode_first_stage(samples_cfg)
- log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
-
- return log
-
-
-class LatentInpaintDiffusion(LatentFinetuneDiffusion):
- """
- can either run as pure inpainting model (only concat mode) or with mixed conditionings,
- e.g. mask as concat and text via cross-attn.
- To disable finetuning mode, set finetune_keys to None
- """
-
- def __init__(self,
- concat_keys=("mask", "masked_image"),
- masked_image_key="masked_image",
- *args, **kwargs
- ):
- super().__init__(concat_keys, *args, **kwargs)
- self.masked_image_key = masked_image_key
- assert self.masked_image_key in concat_keys
-
- @torch.no_grad()
- def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
- # note: restricted to non-trainable encoders currently
- assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
- force_c_encode=True, return_original_cond=True, bs=bs)
-
- assert exists(self.concat_keys)
- c_cat = list()
- for ck in self.concat_keys:
- cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
- if bs is not None:
- cc = cc[:bs]
- cc = cc.to(self.device)
- bchw = z.shape
- if ck != self.masked_image_key:
- cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
- else:
- cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
- c_cat.append(cc)
- c_cat = torch.cat(c_cat, dim=1)
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
- if return_first_stage_outputs:
- return z, all_conds, x, xrec, xc
- return z, all_conds
-
- @torch.no_grad()
- def log_images(self, *args, **kwargs):
- log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
- log["masked_image"] = rearrange(args[0]["masked_image"],
- 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
- return log
-
-
-class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
- """
- condition on monocular depth estimation
- """
-
- def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
- super().__init__(concat_keys=concat_keys, *args, **kwargs)
- self.depth_model = instantiate_from_config(depth_stage_config)
- self.depth_stage_key = concat_keys[0]
-
- @torch.no_grad()
- def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
- # note: restricted to non-trainable encoders currently
- assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
- force_c_encode=True, return_original_cond=True, bs=bs)
-
- assert exists(self.concat_keys)
- assert len(self.concat_keys) == 1
- c_cat = list()
- for ck in self.concat_keys:
- cc = batch[ck]
- if bs is not None:
- cc = cc[:bs]
- cc = cc.to(self.device)
- cc = self.depth_model(cc)
- cc = torch.nn.functional.interpolate(
- cc,
- size=z.shape[2:],
- mode="bicubic",
- align_corners=False,
- )
-
- depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
- keepdim=True)
- cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
- c_cat.append(cc)
- c_cat = torch.cat(c_cat, dim=1)
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
- if return_first_stage_outputs:
- return z, all_conds, x, xrec, xc
- return z, all_conds
-
- @torch.no_grad()
- def log_images(self, *args, **kwargs):
- log = super().log_images(*args, **kwargs)
- depth = self.depth_model(args[0][self.depth_stage_key])
- depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
- torch.amax(depth, dim=[1, 2, 3], keepdim=True)
- log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
- return log
-
-
-class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
- """
- condition on low-res image (and optionally on some spatial noise augmentation)
- """
- def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
- low_scale_config=None, low_scale_key=None, *args, **kwargs):
- super().__init__(concat_keys=concat_keys, *args, **kwargs)
- self.reshuffle_patch_size = reshuffle_patch_size
- self.low_scale_model = None
- if low_scale_config is not None:
- print("Initializing a low-scale model")
- assert exists(low_scale_key)
- self.instantiate_low_stage(low_scale_config)
- self.low_scale_key = low_scale_key
-
- def instantiate_low_stage(self, config):
- model = instantiate_from_config(config)
- self.low_scale_model = model.eval()
- self.low_scale_model.train = disabled_train
- for param in self.low_scale_model.parameters():
- param.requires_grad = False
-
- @torch.no_grad()
- def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
- # note: restricted to non-trainable encoders currently
- assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
- force_c_encode=True, return_original_cond=True, bs=bs)
-
- assert exists(self.concat_keys)
- assert len(self.concat_keys) == 1
- # optionally make spatial noise_level here
- c_cat = list()
- noise_level = None
- for ck in self.concat_keys:
- cc = batch[ck]
- cc = rearrange(cc, 'b h w c -> b c h w')
- if exists(self.reshuffle_patch_size):
- assert isinstance(self.reshuffle_patch_size, int)
- cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
- p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
- if bs is not None:
- cc = cc[:bs]
- cc = cc.to(self.device)
- if exists(self.low_scale_model) and ck == self.low_scale_key:
- cc, noise_level = self.low_scale_model(cc)
- c_cat.append(cc)
- c_cat = torch.cat(c_cat, dim=1)
- if exists(noise_level):
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
- else:
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
- if return_first_stage_outputs:
- return z, all_conds, x, xrec, xc
- return z, all_conds
-
- @torch.no_grad()
- def log_images(self, *args, **kwargs):
- log = super().log_images(*args, **kwargs)
- log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
- return log
-
-
-class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
- def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5,
- freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.embed_key = embedding_key
- self.embedding_dropout = embedding_dropout
- # self._init_embedder(embedder_config, freeze_embedder)
- self._init_noise_aug(noise_aug_config)
-
- def _init_embedder(self, config, freeze=True):
- embedder = instantiate_from_config(config)
- if freeze:
- self.embedder = embedder.eval()
- self.embedder.train = disabled_train
- for param in self.embedder.parameters():
- param.requires_grad = False
-
- def _init_noise_aug(self, config):
- if config is not None:
- # use the KARLO schedule for noise augmentation on CLIP image embeddings
- noise_augmentor = instantiate_from_config(config)
- assert isinstance(noise_augmentor, nn.Module)
- noise_augmentor = noise_augmentor.eval()
- noise_augmentor.train = disabled_train
- self.noise_augmentor = noise_augmentor
- else:
- self.noise_augmentor = None
-
- def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
- outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
- z, c = outputs[0], outputs[1]
- img = batch[self.embed_key][:bs]
- img = rearrange(img, 'b h w c -> b c h w')
- c_adm = self.embedder(img)
- if self.noise_augmentor is not None:
- c_adm, noise_level_emb = self.noise_augmentor(c_adm)
- # assume this gives embeddings of noise levels
- c_adm = torch.cat((c_adm, noise_level_emb), 1)
- if self.training:
- c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
- device=c_adm.device)[:, None]) * c_adm
- all_conds = {"c_crossattn": [c], "c_adm": c_adm}
- noutputs = [z, all_conds]
- noutputs.extend(outputs[2:])
- return noutputs
-
- @torch.no_grad()
- def log_images(self, batch, N=8, n_row=4, **kwargs):
- log = dict()
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
- return_original_cond=True)
- log["inputs"] = x
- log["reconstruction"] = xrec
- assert self.model.conditioning_key is not None
- assert self.cond_stage_key in ["caption", "txt"]
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
- log["conditioning"] = xc
- uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
- unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
-
- uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
- ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
- with ema_scope(f"Sampling"):
- samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
- ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=uc_, )
- x_samples_cfg = self.decode_first_stage(samples_cfg)
- log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
- return log
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 573f4e1c6..a0d695693 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -10,6 +10,7 @@ from .diffusionmodules.util import checkpoint
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
+import comfy.ops
from . import tomesd
@@ -50,9 +51,9 @@ def init_(tensor):
# feedforward
class GEGLU(nn.Module):
- def __init__(self, dim_in, dim_out):
+ def __init__(self, dim_in, dim_out, dtype=None):
super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
+ self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@@ -60,19 +61,19 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
- nn.Linear(dim, inner_dim),
+ comfy.ops.Linear(dim, inner_dim, dtype=dtype),
nn.GELU()
- ) if not glu else GEGLU(dim, inner_dim)
+ ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
- nn.Linear(inner_dim, dim_out)
+ comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
)
def forward(self, x):
@@ -88,8 +89,8 @@ def zero_module(module):
return module
-def Normalize(in_channels):
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+def Normalize(in_channels, dtype=None):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
class SpatialSelfAttention(nn.Module):
@@ -146,7 +147,7 @@ class SpatialSelfAttention(nn.Module):
class CrossAttentionBirchSan(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -154,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+ self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+ self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_out = nn.Sequential(
- nn.Linear(inner_dim, query_dim),
+ comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
nn.Dropout(dropout)
)
@@ -243,7 +244,7 @@ class CrossAttentionBirchSan(nn.Module):
class CrossAttentionDoggettx(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -251,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+ self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+ self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_out = nn.Sequential(
- nn.Linear(inner_dim, query_dim),
+ comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
nn.Dropout(dropout)
)
@@ -341,7 +342,7 @@ class CrossAttentionDoggettx(nn.Module):
return self.to_out(r2)
class CrossAttention(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -349,12 +350,12 @@ class CrossAttention(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+ self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+ self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_out = nn.Sequential(
- nn.Linear(inner_dim, query_dim),
+ comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
nn.Dropout(dropout)
)
@@ -397,7 +398,7 @@ class CrossAttention(nn.Module):
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
super().__init__()
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads.")
@@ -407,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+ self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+ self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
- self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
@@ -448,7 +449,7 @@ class MemoryEfficientCrossAttention(nn.Module):
return self.to_out(out)
class CrossAttentionPytorch(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -456,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
self.heads = heads
self.dim_head = dim_head
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+ self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+ self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
- self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
@@ -506,26 +507,28 @@ else:
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
- disable_self_attn=False):
+ disable_self_attn=False, dtype=None):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
- context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ context_dim=context_dim if self.disable_self_attn else None, dtype=dtype) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
- heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
- self.norm1 = nn.LayerNorm(dim)
- self.norm2 = nn.LayerNorm(dim)
- self.norm3 = nn.LayerNorm(dim)
+ heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim, dtype=dtype)
+ self.norm2 = nn.LayerNorm(dim, dtype=dtype)
+ self.norm3 = nn.LayerNorm(dim, dtype=dtype)
self.checkpoint = checkpoint
def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
- current_index = None
+ extra_options = {}
if "current_index" in transformer_options:
- current_index = transformer_options["current_index"]
+ extra_options["transformer_index"] = transformer_options["current_index"]
+ if "block_index" in transformer_options:
+ extra_options["block_index"] = transformer_options["block_index"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
@@ -544,7 +547,7 @@ class BasicTransformerBlock(nn.Module):
context_attn1 = n
value_attn1 = context_attn1
for p in patch:
- n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
+ n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@@ -556,7 +559,7 @@ class BasicTransformerBlock(nn.Module):
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
- x = p(current_index, x)
+ x = p(x, extra_options)
n = self.norm2(x)
@@ -566,10 +569,15 @@ class BasicTransformerBlock(nn.Module):
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
for p in patch:
- n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
+ n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
n = self.attn2(n, context=context_attn2, value=value_attn2)
+ if "attn2_output_patch" in transformer_patches:
+ patch = transformer_patches["attn2_output_patch"]
+ for p in patch:
+ n = p(n, extra_options)
+
x += n
x = self.ff(self.norm3(x)) + x
return x
@@ -587,35 +595,34 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
- use_checkpoint=True):
+ use_checkpoint=True, dtype=None):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
- self.norm = Normalize(in_channels)
+ self.norm = Normalize(in_channels, dtype=dtype)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
- padding=0)
+ padding=0, dtype=dtype)
else:
- self.proj_in = nn.Linear(in_channels, inner_dim)
+ self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
- disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
for d in range(depth)]
)
if not use_linear:
- self.proj_out = zero_module(nn.Conv2d(inner_dim,
- in_channels,
+ self.proj_out = nn.Conv2d(inner_dim,in_channels,
kernel_size=1,
stride=1,
- padding=0))
+ padding=0, dtype=dtype)
else:
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
self.use_linear = use_linear
def forward(self, x, context=None, transformer_options={}):
@@ -631,6 +638,7 @@ class SpatialTransformer(nn.Module):
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
+ transformer_options["block_index"] = i
x = block(x, context=context[i], transformer_options=transformer_options)
if self.use_linear:
x = self.proj_out(x)
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 5aef23f33..e170f6779 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -111,14 +111,14 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -169,7 +169,7 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
- dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
)
else:
assert self.channels == self.out_channels
@@ -208,6 +208,7 @@ class ResBlock(TimestepBlock):
use_checkpoint=False,
up=False,
down=False,
+ dtype=None
):
super().__init__()
self.channels = channels
@@ -219,19 +220,19 @@ class ResBlock(TimestepBlock):
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
- normalization(channels),
+ normalization(channels, dtype=dtype),
nn.SiLU(),
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
)
self.updown = up or down
if up:
- self.h_upd = Upsample(channels, False, dims)
- self.x_upd = Upsample(channels, False, dims)
+ self.h_upd = Upsample(channels, False, dims, dtype=dtype)
+ self.x_upd = Upsample(channels, False, dims, dtype=dtype)
elif down:
- self.h_upd = Downsample(channels, False, dims)
- self.x_upd = Downsample(channels, False, dims)
+ self.h_upd = Downsample(channels, False, dims, dtype=dtype)
+ self.x_upd = Downsample(channels, False, dims, dtype=dtype)
else:
self.h_upd = self.x_upd = nn.Identity()
@@ -239,15 +240,15 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
linear(
emb_channels,
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
),
)
self.out_layers = nn.Sequential(
- normalization(self.out_channels),
+ normalization(self.out_channels, dtype=dtype),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
),
)
@@ -255,10 +256,10 @@ class ResBlock(TimestepBlock):
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
- dims, channels, self.out_channels, 3, padding=1
+ dims, channels, self.out_channels, 3, padding=1, dtype=dtype
)
else:
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
def forward(self, x, emb):
"""
@@ -558,9 +559,9 @@ class UNetModel(nn.Module):
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
- linear(model_channels, time_embed_dim),
+ linear(model_channels, time_embed_dim, dtype=self.dtype),
nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
+ linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
)
if self.num_classes is not None:
@@ -573,9 +574,9 @@ class UNetModel(nn.Module):
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
- linear(adm_in_channels, time_embed_dim),
+ linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
+ linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
)
)
else:
@@ -584,7 +585,7 @@ class UNetModel(nn.Module):
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
)
]
)
@@ -603,6 +604,7 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype
)
]
ch = mult * model_channels
@@ -631,7 +633,7 @@ class UNetModel(nn.Module):
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint
+ use_checkpoint=use_checkpoint, dtype=self.dtype
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -650,10 +652,11 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
+ dtype=self.dtype
)
if resblock_updown
else Downsample(
- ch, conv_resample, dims=dims, out_channels=out_ch
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
)
)
)
@@ -678,6 +681,7 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype
),
AttentionBlock(
ch,
@@ -688,7 +692,7 @@ class UNetModel(nn.Module):
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint
+ use_checkpoint=use_checkpoint, dtype=self.dtype
),
ResBlock(
ch,
@@ -697,6 +701,7 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype
),
)
self._feature_size += ch
@@ -714,6 +719,7 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype
)
]
ch = model_channels * mult
@@ -742,7 +748,7 @@ class UNetModel(nn.Module):
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint
+ use_checkpoint=use_checkpoint, dtype=self.dtype
)
)
if level and i == self.num_res_blocks[level]:
@@ -757,18 +763,19 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
+ dtype=self.dtype
)
if resblock_updown
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
- normalization(ch),
+ normalization(ch, dtype=self.dtype),
nn.SiLU(),
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py
index 82ea3f0a6..d890c8044 100644
--- a/comfy/ldm/modules/diffusionmodules/util.py
+++ b/comfy/ldm/modules/diffusionmodules/util.py
@@ -16,7 +16,7 @@ import numpy as np
from einops import repeat
from comfy.ldm.util import instantiate_from_config
-
+import comfy.ops
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
@@ -206,13 +206,13 @@ def mean_flat(tensor):
return tensor.mean(dim=list(range(1, len(tensor.shape))))
-def normalization(channels):
+def normalization(channels, dtype=None):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
- return GroupNorm32(32, channels)
+ return GroupNorm32(32, channels, dtype=dtype)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
@@ -233,7 +233,7 @@ def conv_nd(dims, *args, **kwargs):
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
- return nn.Conv2d(*args, **kwargs)
+ return comfy.ops.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
@@ -243,7 +243,7 @@ def linear(*args, **kwargs):
"""
Create a linear module.
"""
- return nn.Linear(*args, **kwargs)
+ return comfy.ops.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
diff --git a/comfy/ldm/modules/encoders/kornia_functions.py b/comfy/ldm/modules/encoders/kornia_functions.py
deleted file mode 100644
index 912314cd7..000000000
--- a/comfy/ldm/modules/encoders/kornia_functions.py
+++ /dev/null
@@ -1,59 +0,0 @@
-
-
-from typing import List, Tuple, Union
-
-import torch
-import torch.nn as nn
-
-#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py
-
-def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
- r"""Normalize an image/video tensor with mean and standard deviation.
- .. math::
- \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}
- Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,
- Args:
- data: Image tensor of size :math:`(B, C, *)`.
- mean: Mean for each channel.
- std: Standard deviations for each channel.
- Return:
- Normalised tensor with same size as input :math:`(B, C, *)`.
- Examples:
- >>> x = torch.rand(1, 4, 3, 3)
- >>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
- >>> out.shape
- torch.Size([1, 4, 3, 3])
- >>> x = torch.rand(1, 4, 3, 3)
- >>> mean = torch.zeros(4)
- >>> std = 255. * torch.ones(4)
- >>> out = normalize(x, mean, std)
- >>> out.shape
- torch.Size([1, 4, 3, 3])
- """
- shape = data.shape
- if len(mean.shape) == 0 or mean.shape[0] == 1:
- mean = mean.expand(shape[1])
- if len(std.shape) == 0 or std.shape[0] == 1:
- std = std.expand(shape[1])
-
- # Allow broadcast on channel dimension
- if mean.shape and mean.shape[0] != 1:
- if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
- raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
-
- # Allow broadcast on channel dimension
- if std.shape and std.shape[0] != 1:
- if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
- raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
-
- mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
- std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
-
- if mean.shape:
- mean = mean[..., :, None]
- if std.shape:
- std = std[..., :, None]
-
- out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std
-
- return out.view(shape)
diff --git a/comfy/ldm/modules/encoders/modules.py b/comfy/ldm/modules/encoders/modules.py
deleted file mode 100644
index bc9fde638..000000000
--- a/comfy/ldm/modules/encoders/modules.py
+++ /dev/null
@@ -1,314 +0,0 @@
-import torch
-import torch.nn as nn
-from . import kornia_functions
-from torch.utils.checkpoint import checkpoint
-
-from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
-
-import open_clip
-from ldm.util import default, count_params
-
-
-class AbstractEncoder(nn.Module):
- def __init__(self):
- super().__init__()
-
- def encode(self, *args, **kwargs):
- raise NotImplementedError
-
-
-class IdentityEncoder(AbstractEncoder):
-
- def encode(self, x):
- return x
-
-
-class ClassEmbedder(nn.Module):
- def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
- super().__init__()
- self.key = key
- self.embedding = nn.Embedding(n_classes, embed_dim)
- self.n_classes = n_classes
- self.ucg_rate = ucg_rate
-
- def forward(self, batch, key=None, disable_dropout=False):
- if key is None:
- key = self.key
- # this is for use in crossattn
- c = batch[key][:, None]
- if self.ucg_rate > 0. and not disable_dropout:
- mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
- c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
- c = c.long()
- c = self.embedding(c)
- return c
-
- def get_unconditional_conditioning(self, bs, device="cuda"):
- uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
- uc = torch.ones((bs,), device=device) * uc_class
- uc = {self.key: uc}
- return uc
-
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-
-class FrozenT5Embedder(AbstractEncoder):
- """Uses the T5 transformer encoder for text"""
-
- def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
- freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
- super().__init__()
- self.tokenizer = T5Tokenizer.from_pretrained(version)
- self.transformer = T5EncoderModel.from_pretrained(version)
- self.device = device
- self.max_length = max_length # TODO: typical value?
- if freeze:
- self.freeze()
-
- def freeze(self):
- self.transformer = self.transformer.eval()
- # self.train = disabled_train
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
- tokens = batch_encoding["input_ids"].to(self.device)
- outputs = self.transformer(input_ids=tokens)
-
- z = outputs.last_hidden_state
- return z
-
- def encode(self, text):
- return self(text)
-
-
-class FrozenCLIPEmbedder(AbstractEncoder):
- """Uses the CLIP transformer encoder for text (from huggingface)"""
- LAYERS = [
- "last",
- "pooled",
- "hidden"
- ]
-
- def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
- freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
- super().__init__()
- assert layer in self.LAYERS
- self.tokenizer = CLIPTokenizer.from_pretrained(version)
- self.transformer = CLIPTextModel.from_pretrained(version)
- self.device = device
- self.max_length = max_length
- if freeze:
- self.freeze()
- self.layer = layer
- self.layer_idx = layer_idx
- if layer == "hidden":
- assert layer_idx is not None
- assert 0 <= abs(layer_idx) <= 12
-
- def freeze(self):
- self.transformer = self.transformer.eval()
- # self.train = disabled_train
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
- tokens = batch_encoding["input_ids"].to(self.device)
- outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
- if self.layer == "last":
- z = outputs.last_hidden_state
- elif self.layer == "pooled":
- z = outputs.pooler_output[:, None, :]
- else:
- z = outputs.hidden_states[self.layer_idx]
- return z
-
- def encode(self, text):
- return self(text)
-
-
-class ClipImageEmbedder(nn.Module):
- def __init__(
- self,
- model,
- jit=False,
- device='cuda' if torch.cuda.is_available() else 'cpu',
- antialias=True,
- ucg_rate=0.
- ):
- super().__init__()
- from clip import load as load_clip
- self.model, _ = load_clip(name=model, device=device, jit=jit)
-
- self.antialias = antialias
-
- self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
- self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
- self.ucg_rate = ucg_rate
-
- def preprocess(self, x):
- # normalize to [0,1]
- # x = kornia_functions.geometry_resize(x, (224, 224),
- # interpolation='bicubic', align_corners=True,
- # antialias=self.antialias)
- x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
- x = (x + 1.) / 2.
- # re-normalize according to clip
- x = kornia_functions.enhance_normalize(x, self.mean, self.std)
- return x
-
- def forward(self, x, no_dropout=False):
- # x is assumed to be in range [-1,1]
- out = self.model.encode_image(self.preprocess(x))
- out = out.to(x.dtype)
- if self.ucg_rate > 0. and not no_dropout:
- out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
- return out
-
-
-class FrozenOpenCLIPEmbedder(AbstractEncoder):
- """
- Uses the OpenCLIP transformer encoder for text
- """
- LAYERS = [
- # "pooled",
- "last",
- "penultimate"
- ]
-
- def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
- freeze=True, layer="last"):
- super().__init__()
- assert layer in self.LAYERS
- model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
- del model.visual
- self.model = model
-
- self.device = device
- self.max_length = max_length
- if freeze:
- self.freeze()
- self.layer = layer
- if self.layer == "last":
- self.layer_idx = 0
- elif self.layer == "penultimate":
- self.layer_idx = 1
- else:
- raise NotImplementedError()
-
- def freeze(self):
- self.model = self.model.eval()
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- tokens = open_clip.tokenize(text)
- z = self.encode_with_transformer(tokens.to(self.device))
- return z
-
- def encode_with_transformer(self, text):
- x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
- x = x + self.model.positional_embedding
- x = x.permute(1, 0, 2) # NLD -> LND
- x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
- x = x.permute(1, 0, 2) # LND -> NLD
- x = self.model.ln_final(x)
- return x
-
- def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
- for i, r in enumerate(self.model.transformer.resblocks):
- if i == len(self.model.transformer.resblocks) - self.layer_idx:
- break
- if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
- x = checkpoint(r, x, attn_mask)
- else:
- x = r(x, attn_mask=attn_mask)
- return x
-
- def encode(self, text):
- return self(text)
-
-
-class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
- """
- Uses the OpenCLIP vision transformer encoder for images
- """
-
- def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
- freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
- super().__init__()
- model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
- pretrained=version, )
- del model.transformer
- self.model = model
-
- self.device = device
- self.max_length = max_length
- if freeze:
- self.freeze()
- self.layer = layer
- if self.layer == "penultimate":
- raise NotImplementedError()
- self.layer_idx = 1
-
- self.antialias = antialias
-
- self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
- self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
- self.ucg_rate = ucg_rate
-
- def preprocess(self, x):
- # normalize to [0,1]
- # x = kornia.geometry.resize(x, (224, 224),
- # interpolation='bicubic', align_corners=True,
- # antialias=self.antialias)
- x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
- x = (x + 1.) / 2.
- # renormalize according to clip
- x = kornia_functions.enhance_normalize(x, self.mean, self.std)
- return x
-
- def freeze(self):
- self.model = self.model.eval()
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, image, no_dropout=False):
- z = self.encode_with_vision_transformer(image)
- if self.ucg_rate > 0. and not no_dropout:
- z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
- return z
-
- def encode_with_vision_transformer(self, img):
- img = self.preprocess(img)
- x = self.model.visual(img)
- return x
-
- def encode(self, text):
- return self(text)
-
-
-class FrozenCLIPT5Encoder(AbstractEncoder):
- def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
- clip_max_length=77, t5_max_length=77):
- super().__init__()
- self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
- self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
- print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
- f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
-
- def encode(self, text):
- return self(text)
-
- def forward(self, text):
- clip_z = self.clip_encoder.encode(text)
- t5_z = self.t5_encoder.encode(text)
- return [clip_z, t5_z]
diff --git a/comfy/ldm/modules/image_degradation/__init__.py b/comfy/ldm/modules/image_degradation/__init__.py
deleted file mode 100644
index 7836cada8..000000000
--- a/comfy/ldm/modules/image_degradation/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
-from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/comfy/ldm/modules/image_degradation/bsrgan.py b/comfy/ldm/modules/image_degradation/bsrgan.py
deleted file mode 100644
index 32ef56169..000000000
--- a/comfy/ldm/modules/image_degradation/bsrgan.py
+++ /dev/null
@@ -1,730 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-# --------------------------------------------
-# Super-Resolution
-# --------------------------------------------
-#
-# Kai Zhang (cskaizhang@gmail.com)
-# https://github.com/cszn
-# From 2019/03--2021/08
-# --------------------------------------------
-"""
-
-import numpy as np
-import cv2
-import torch
-
-from functools import partial
-import random
-from scipy import ndimage
-import scipy
-import scipy.stats as ss
-from scipy.interpolate import interp2d
-from scipy.linalg import orth
-import albumentations
-
-import ldm.modules.image_degradation.utils_image as util
-
-
-def modcrop_np(img, sf):
- '''
- Args:
- img: numpy image, WxH or WxHxC
- sf: scale factor
- Return:
- cropped image
- '''
- w, h = img.shape[:2]
- im = np.copy(img)
- return im[:w - w % sf, :h - h % sf, ...]
-
-
-"""
-# --------------------------------------------
-# anisotropic Gaussian kernels
-# --------------------------------------------
-"""
-
-
-def analytic_kernel(k):
- """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
- k_size = k.shape[0]
- # Calculate the big kernels size
- big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
- # Loop over the small kernel to fill the big one
- for r in range(k_size):
- for c in range(k_size):
- big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
- # Crop the edges of the big kernel to ignore very small values and increase run time of SR
- crop = k_size // 2
- cropped_big_k = big_k[crop:-crop, crop:-crop]
- # Normalize to 1
- return cropped_big_k / cropped_big_k.sum()
-
-
-def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
- """ generate an anisotropic Gaussian kernel
- Args:
- ksize : e.g., 15, kernel size
- theta : [0, pi], rotation angle range
- l1 : [0.1,50], scaling of eigenvalues
- l2 : [0.1,l1], scaling of eigenvalues
- If l1 = l2, will get an isotropic Gaussian kernel.
- Returns:
- k : kernel
- """
-
- v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
- V = np.array([[v[0], v[1]], [v[1], -v[0]]])
- D = np.array([[l1, 0], [0, l2]])
- Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
- k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
-
- return k
-
-
-def gm_blur_kernel(mean, cov, size=15):
- center = size / 2.0 + 0.5
- k = np.zeros([size, size])
- for y in range(size):
- for x in range(size):
- cy = y - center + 1
- cx = x - center + 1
- k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
-
- k = k / np.sum(k)
- return k
-
-
-def shift_pixel(x, sf, upper_left=True):
- """shift pixel for super-resolution with different scale factors
- Args:
- x: WxHxC or WxH
- sf: scale factor
- upper_left: shift direction
- """
- h, w = x.shape[:2]
- shift = (sf - 1) * 0.5
- xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
- if upper_left:
- x1 = xv + shift
- y1 = yv + shift
- else:
- x1 = xv - shift
- y1 = yv - shift
-
- x1 = np.clip(x1, 0, w - 1)
- y1 = np.clip(y1, 0, h - 1)
-
- if x.ndim == 2:
- x = interp2d(xv, yv, x)(x1, y1)
- if x.ndim == 3:
- for i in range(x.shape[-1]):
- x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
-
- return x
-
-
-def blur(x, k):
- '''
- x: image, NxcxHxW
- k: kernel, Nx1xhxw
- '''
- n, c = x.shape[:2]
- p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
- x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
- k = k.repeat(1, c, 1, 1)
- k = k.view(-1, 1, k.shape[2], k.shape[3])
- x = x.view(1, -1, x.shape[2], x.shape[3])
- x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
- x = x.view(n, c, x.shape[2], x.shape[3])
-
- return x
-
-
-def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
- """"
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
- # Kai Zhang
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
- # max_var = 2.5 * sf
- """
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
- theta = np.random.rand() * np.pi # random theta
- noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
-
- # Set COV matrix using Lambdas and Theta
- LAMBDA = np.diag([lambda_1, lambda_2])
- Q = np.array([[np.cos(theta), -np.sin(theta)],
- [np.sin(theta), np.cos(theta)]])
- SIGMA = Q @ LAMBDA @ Q.T
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
-
- # Set expectation position (shifting kernel for aligned image)
- MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
- MU = MU[None, None, :, None]
-
- # Create meshgrid for Gaussian
- [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
- Z = np.stack([X, Y], 2)[:, :, :, None]
-
- # Calcualte Gaussian for every pixel of the kernel
- ZZ = Z - MU
- ZZ_t = ZZ.transpose(0, 1, 3, 2)
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
-
- # shift the kernel so it will be centered
- # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
-
- # Normalize the kernel and return
- # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
- kernel = raw_kernel / np.sum(raw_kernel)
- return kernel
-
-
-def fspecial_gaussian(hsize, sigma):
- hsize = [hsize, hsize]
- siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
- std = sigma
- [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
- arg = -(x * x + y * y) / (2 * std * std)
- h = np.exp(arg)
- h[h < scipy.finfo(float).eps * h.max()] = 0
- sumh = h.sum()
- if sumh != 0:
- h = h / sumh
- return h
-
-
-def fspecial_laplacian(alpha):
- alpha = max([0, min([alpha, 1])])
- h1 = alpha / (alpha + 1)
- h2 = (1 - alpha) / (alpha + 1)
- h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
- h = np.array(h)
- return h
-
-
-def fspecial(filter_type, *args, **kwargs):
- '''
- python code from:
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
- '''
- if filter_type == 'gaussian':
- return fspecial_gaussian(*args, **kwargs)
- if filter_type == 'laplacian':
- return fspecial_laplacian(*args, **kwargs)
-
-
-"""
-# --------------------------------------------
-# degradation models
-# --------------------------------------------
-"""
-
-
-def bicubic_degradation(x, sf=3):
- '''
- Args:
- x: HxWxC image, [0, 1]
- sf: down-scale factor
- Return:
- bicubicly downsampled LR image
- '''
- x = util.imresize_np(x, scale=1 / sf)
- return x
-
-
-def srmd_degradation(x, k, sf=3):
- ''' blur + bicubic downsampling
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2018learning,
- title={Learning a single convolutional super-resolution network for multiple degradations},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={3262--3271},
- year={2018}
- }
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
- x = bicubic_degradation(x, sf=sf)
- return x
-
-
-def dpsr_degradation(x, k, sf=3):
- ''' bicubic downsampling + blur
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2019deep,
- title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={1671--1681},
- year={2019}
- }
- '''
- x = bicubic_degradation(x, sf=sf)
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- return x
-
-
-def classical_degradation(x, k, sf=3):
- ''' blur + downsampling
- Args:
- x: HxWxC image, [0, 1]/[0, 255]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
- st = 0
- return x[st::sf, st::sf, ...]
-
-
-def add_sharpening(img, weight=0.5, radius=50, threshold=10):
- """USM sharpening. borrowed from real-ESRGAN
- Input image: I; Blurry image: B.
- 1. K = I + weight * (I - B)
- 2. Mask = 1 if abs(I - B) > threshold, else: 0
- 3. Blur mask:
- 4. Out = Mask * K + (1 - Mask) * I
- Args:
- img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
- weight (float): Sharp weight. Default: 1.
- radius (float): Kernel size of Gaussian blur. Default: 50.
- threshold (int):
- """
- if radius % 2 == 0:
- radius += 1
- blur = cv2.GaussianBlur(img, (radius, radius), 0)
- residual = img - blur
- mask = np.abs(residual) * 255 > threshold
- mask = mask.astype('float32')
- soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
-
- K = img + weight * residual
- K = np.clip(K, 0, 1)
- return soft_mask * K + (1 - soft_mask) * img
-
-
-def add_blur(img, sf=4):
- wd2 = 4.0 + sf
- wd = 2.0 + 0.2 * sf
- if random.random() < 0.5:
- l1 = wd2 * random.random()
- l2 = wd2 * random.random()
- k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
- else:
- k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
- img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
-
- return img
-
-
-def add_resize(img, sf=4):
- rnum = np.random.rand()
- if rnum > 0.8: # up
- sf1 = random.uniform(1, 2)
- elif rnum < 0.7: # down
- sf1 = random.uniform(0.5 / sf, 1)
- else:
- sf1 = 1.0
- img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- return img
-
-
-# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
-# noise_level = random.randint(noise_level1, noise_level2)
-# rnum = np.random.rand()
-# if rnum > 0.6: # add color Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
-# elif rnum < 0.4: # add grayscale Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
-# else: # add noise
-# L = noise_level2 / 255.
-# D = np.diag(np.random.rand(3))
-# U = orth(np.random.rand(3, 3))
-# conv = np.dot(np.dot(np.transpose(U), D), U)
-# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
-# img = np.clip(img, 0.0, 1.0)
-# return img
-
-def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- rnum = np.random.rand()
- if rnum > 0.6: # add color Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4: # add grayscale Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else: # add noise
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_speckle_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- img = np.clip(img, 0.0, 1.0)
- rnum = random.random()
- if rnum > 0.6:
- img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4:
- img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else:
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_Poisson_noise(img):
- img = np.clip((img * 255.0).round(), 0, 255) / 255.
- vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
- if random.random() < 0.5:
- img = np.random.poisson(img * vals).astype(np.float32) / vals
- else:
- img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
- img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
- noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
- img += noise_gray[:, :, np.newaxis]
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_JPEG_noise(img):
- quality_factor = random.randint(30, 95)
- img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
- result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
- img = cv2.imdecode(encimg, 1)
- img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
- return img
-
-
-def random_crop(lq, hq, sf=4, lq_patchsize=64):
- h, w = lq.shape[:2]
- rnd_h = random.randint(0, h - lq_patchsize)
- rnd_w = random.randint(0, w - lq_patchsize)
- lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
-
- rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
- hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
- return lq, hq
-
-
-def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize * sf or w < lq_patchsize * sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- hq = img.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- img = util.imresize_np(img, 1 / 2, True)
- img = np.clip(img, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- img = add_blur(img, sf=sf)
-
- elif i == 1:
- img = add_blur(img, sf=sf)
-
- elif i == 2:
- a, b = img.shape[1], img.shape[0]
- # downsample2
- if random.random() < 0.75:
- sf1 = random.uniform(1, 2 * sf)
- img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
- img = img[0::sf, 0::sf, ...] # nearest downsampling
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- img = add_JPEG_noise(img)
-
- elif i == 6:
- # add processed camera sensor noise
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
-
- return img, hq
-
-
-# todo no isp_model?
-def degradation_bsrgan_variant(image, sf=4, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- image = util.uint2single(image)
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = image.shape[:2]
- image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = image.shape[:2]
-
- hq = image.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- image = util.imresize_np(image, 1 / 2, True)
- image = np.clip(image, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- image = add_blur(image, sf=sf)
-
- elif i == 1:
- image = add_blur(image, sf=sf)
-
- elif i == 2:
- a, b = image.shape[1], image.shape[0]
- # downsample2
- if random.random() < 0.75:
- sf1 = random.uniform(1, 2 * sf)
- image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
- image = image[0::sf, 0::sf, ...] # nearest downsampling
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- image = add_JPEG_noise(image)
-
- # elif i == 6:
- # # add processed camera sensor noise
- # if random.random() < isp_prob and isp_model is not None:
- # with torch.no_grad():
- # img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- image = add_JPEG_noise(image)
- image = util.single2uint(image)
- example = {"image":image}
- return example
-
-
-# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
-def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
- """
- This is an extended degradation model by combining
- the degradation models of BSRGAN and Real-ESRGAN
- ----------
- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
- sf: scale factor
- use_shuffle: the degradation shuffle
- use_sharp: sharpening the img
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize * sf or w < lq_patchsize * sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- if use_sharp:
- img = add_sharpening(img)
- hq = img.copy()
-
- if random.random() < shuffle_prob:
- shuffle_order = random.sample(range(13), 13)
- else:
- shuffle_order = list(range(13))
- # local shuffle for noise, JPEG is always the last one
- shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
- shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
-
- poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
-
- for i in shuffle_order:
- if i == 0:
- img = add_blur(img, sf=sf)
- elif i == 1:
- img = add_resize(img, sf=sf)
- elif i == 2:
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
- elif i == 3:
- if random.random() < poisson_prob:
- img = add_Poisson_noise(img)
- elif i == 4:
- if random.random() < speckle_prob:
- img = add_speckle_noise(img)
- elif i == 5:
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
- elif i == 6:
- img = add_JPEG_noise(img)
- elif i == 7:
- img = add_blur(img, sf=sf)
- elif i == 8:
- img = add_resize(img, sf=sf)
- elif i == 9:
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
- elif i == 10:
- if random.random() < poisson_prob:
- img = add_Poisson_noise(img)
- elif i == 11:
- if random.random() < speckle_prob:
- img = add_speckle_noise(img)
- elif i == 12:
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
- else:
- print('check the shuffle!')
-
- # resize to desired size
- img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
- interpolation=random.choice([1, 2, 3]))
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf, lq_patchsize)
-
- return img, hq
-
-
-if __name__ == '__main__':
- print("hey")
- img = util.imread_uint('utils/test.png', 3)
- print(img)
- img = util.uint2single(img)
- print(img)
- img = img[:448, :448]
- h = img.shape[0] // 4
- print("resizing to", h)
- sf = 4
- deg_fn = partial(degradation_bsrgan_variant, sf=sf)
- for i in range(20):
- print(i)
- img_lq = deg_fn(img)
- print(img_lq)
- img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
- print(img_lq.shape)
- print("bicubic", img_lq_bicubic.shape)
- print(img_hq.shape)
- lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
- util.imsave(img_concat, str(i) + '.png')
-
-
diff --git a/comfy/ldm/modules/image_degradation/bsrgan_light.py b/comfy/ldm/modules/image_degradation/bsrgan_light.py
deleted file mode 100644
index 808c7f882..000000000
--- a/comfy/ldm/modules/image_degradation/bsrgan_light.py
+++ /dev/null
@@ -1,651 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import cv2
-import torch
-
-from functools import partial
-import random
-from scipy import ndimage
-import scipy
-import scipy.stats as ss
-from scipy.interpolate import interp2d
-from scipy.linalg import orth
-import albumentations
-
-import ldm.modules.image_degradation.utils_image as util
-
-"""
-# --------------------------------------------
-# Super-Resolution
-# --------------------------------------------
-#
-# Kai Zhang (cskaizhang@gmail.com)
-# https://github.com/cszn
-# From 2019/03--2021/08
-# --------------------------------------------
-"""
-
-def modcrop_np(img, sf):
- '''
- Args:
- img: numpy image, WxH or WxHxC
- sf: scale factor
- Return:
- cropped image
- '''
- w, h = img.shape[:2]
- im = np.copy(img)
- return im[:w - w % sf, :h - h % sf, ...]
-
-
-"""
-# --------------------------------------------
-# anisotropic Gaussian kernels
-# --------------------------------------------
-"""
-
-
-def analytic_kernel(k):
- """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
- k_size = k.shape[0]
- # Calculate the big kernels size
- big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
- # Loop over the small kernel to fill the big one
- for r in range(k_size):
- for c in range(k_size):
- big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
- # Crop the edges of the big kernel to ignore very small values and increase run time of SR
- crop = k_size // 2
- cropped_big_k = big_k[crop:-crop, crop:-crop]
- # Normalize to 1
- return cropped_big_k / cropped_big_k.sum()
-
-
-def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
- """ generate an anisotropic Gaussian kernel
- Args:
- ksize : e.g., 15, kernel size
- theta : [0, pi], rotation angle range
- l1 : [0.1,50], scaling of eigenvalues
- l2 : [0.1,l1], scaling of eigenvalues
- If l1 = l2, will get an isotropic Gaussian kernel.
- Returns:
- k : kernel
- """
-
- v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
- V = np.array([[v[0], v[1]], [v[1], -v[0]]])
- D = np.array([[l1, 0], [0, l2]])
- Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
- k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
-
- return k
-
-
-def gm_blur_kernel(mean, cov, size=15):
- center = size / 2.0 + 0.5
- k = np.zeros([size, size])
- for y in range(size):
- for x in range(size):
- cy = y - center + 1
- cx = x - center + 1
- k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
-
- k = k / np.sum(k)
- return k
-
-
-def shift_pixel(x, sf, upper_left=True):
- """shift pixel for super-resolution with different scale factors
- Args:
- x: WxHxC or WxH
- sf: scale factor
- upper_left: shift direction
- """
- h, w = x.shape[:2]
- shift = (sf - 1) * 0.5
- xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
- if upper_left:
- x1 = xv + shift
- y1 = yv + shift
- else:
- x1 = xv - shift
- y1 = yv - shift
-
- x1 = np.clip(x1, 0, w - 1)
- y1 = np.clip(y1, 0, h - 1)
-
- if x.ndim == 2:
- x = interp2d(xv, yv, x)(x1, y1)
- if x.ndim == 3:
- for i in range(x.shape[-1]):
- x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
-
- return x
-
-
-def blur(x, k):
- '''
- x: image, NxcxHxW
- k: kernel, Nx1xhxw
- '''
- n, c = x.shape[:2]
- p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
- x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
- k = k.repeat(1, c, 1, 1)
- k = k.view(-1, 1, k.shape[2], k.shape[3])
- x = x.view(1, -1, x.shape[2], x.shape[3])
- x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
- x = x.view(n, c, x.shape[2], x.shape[3])
-
- return x
-
-
-def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
- """"
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
- # Kai Zhang
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
- # max_var = 2.5 * sf
- """
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
- theta = np.random.rand() * np.pi # random theta
- noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
-
- # Set COV matrix using Lambdas and Theta
- LAMBDA = np.diag([lambda_1, lambda_2])
- Q = np.array([[np.cos(theta), -np.sin(theta)],
- [np.sin(theta), np.cos(theta)]])
- SIGMA = Q @ LAMBDA @ Q.T
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
-
- # Set expectation position (shifting kernel for aligned image)
- MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
- MU = MU[None, None, :, None]
-
- # Create meshgrid for Gaussian
- [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
- Z = np.stack([X, Y], 2)[:, :, :, None]
-
- # Calcualte Gaussian for every pixel of the kernel
- ZZ = Z - MU
- ZZ_t = ZZ.transpose(0, 1, 3, 2)
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
-
- # shift the kernel so it will be centered
- # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
-
- # Normalize the kernel and return
- # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
- kernel = raw_kernel / np.sum(raw_kernel)
- return kernel
-
-
-def fspecial_gaussian(hsize, sigma):
- hsize = [hsize, hsize]
- siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
- std = sigma
- [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
- arg = -(x * x + y * y) / (2 * std * std)
- h = np.exp(arg)
- h[h < scipy.finfo(float).eps * h.max()] = 0
- sumh = h.sum()
- if sumh != 0:
- h = h / sumh
- return h
-
-
-def fspecial_laplacian(alpha):
- alpha = max([0, min([alpha, 1])])
- h1 = alpha / (alpha + 1)
- h2 = (1 - alpha) / (alpha + 1)
- h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
- h = np.array(h)
- return h
-
-
-def fspecial(filter_type, *args, **kwargs):
- '''
- python code from:
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
- '''
- if filter_type == 'gaussian':
- return fspecial_gaussian(*args, **kwargs)
- if filter_type == 'laplacian':
- return fspecial_laplacian(*args, **kwargs)
-
-
-"""
-# --------------------------------------------
-# degradation models
-# --------------------------------------------
-"""
-
-
-def bicubic_degradation(x, sf=3):
- '''
- Args:
- x: HxWxC image, [0, 1]
- sf: down-scale factor
- Return:
- bicubicly downsampled LR image
- '''
- x = util.imresize_np(x, scale=1 / sf)
- return x
-
-
-def srmd_degradation(x, k, sf=3):
- ''' blur + bicubic downsampling
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2018learning,
- title={Learning a single convolutional super-resolution network for multiple degradations},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={3262--3271},
- year={2018}
- }
- '''
- x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
- x = bicubic_degradation(x, sf=sf)
- return x
-
-
-def dpsr_degradation(x, k, sf=3):
- ''' bicubic downsampling + blur
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2019deep,
- title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={1671--1681},
- year={2019}
- }
- '''
- x = bicubic_degradation(x, sf=sf)
- x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- return x
-
-
-def classical_degradation(x, k, sf=3):
- ''' blur + downsampling
- Args:
- x: HxWxC image, [0, 1]/[0, 255]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- '''
- x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
- st = 0
- return x[st::sf, st::sf, ...]
-
-
-def add_sharpening(img, weight=0.5, radius=50, threshold=10):
- """USM sharpening. borrowed from real-ESRGAN
- Input image: I; Blurry image: B.
- 1. K = I + weight * (I - B)
- 2. Mask = 1 if abs(I - B) > threshold, else: 0
- 3. Blur mask:
- 4. Out = Mask * K + (1 - Mask) * I
- Args:
- img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
- weight (float): Sharp weight. Default: 1.
- radius (float): Kernel size of Gaussian blur. Default: 50.
- threshold (int):
- """
- if radius % 2 == 0:
- radius += 1
- blur = cv2.GaussianBlur(img, (radius, radius), 0)
- residual = img - blur
- mask = np.abs(residual) * 255 > threshold
- mask = mask.astype('float32')
- soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
-
- K = img + weight * residual
- K = np.clip(K, 0, 1)
- return soft_mask * K + (1 - soft_mask) * img
-
-
-def add_blur(img, sf=4):
- wd2 = 4.0 + sf
- wd = 2.0 + 0.2 * sf
-
- wd2 = wd2/4
- wd = wd/4
-
- if random.random() < 0.5:
- l1 = wd2 * random.random()
- l2 = wd2 * random.random()
- k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
- else:
- k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
- img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
-
- return img
-
-
-def add_resize(img, sf=4):
- rnum = np.random.rand()
- if rnum > 0.8: # up
- sf1 = random.uniform(1, 2)
- elif rnum < 0.7: # down
- sf1 = random.uniform(0.5 / sf, 1)
- else:
- sf1 = 1.0
- img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- return img
-
-
-# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
-# noise_level = random.randint(noise_level1, noise_level2)
-# rnum = np.random.rand()
-# if rnum > 0.6: # add color Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
-# elif rnum < 0.4: # add grayscale Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
-# else: # add noise
-# L = noise_level2 / 255.
-# D = np.diag(np.random.rand(3))
-# U = orth(np.random.rand(3, 3))
-# conv = np.dot(np.dot(np.transpose(U), D), U)
-# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
-# img = np.clip(img, 0.0, 1.0)
-# return img
-
-def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- rnum = np.random.rand()
- if rnum > 0.6: # add color Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4: # add grayscale Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else: # add noise
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_speckle_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- img = np.clip(img, 0.0, 1.0)
- rnum = random.random()
- if rnum > 0.6:
- img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4:
- img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else:
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_Poisson_noise(img):
- img = np.clip((img * 255.0).round(), 0, 255) / 255.
- vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
- if random.random() < 0.5:
- img = np.random.poisson(img * vals).astype(np.float32) / vals
- else:
- img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
- img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
- noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
- img += noise_gray[:, :, np.newaxis]
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_JPEG_noise(img):
- quality_factor = random.randint(80, 95)
- img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
- result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
- img = cv2.imdecode(encimg, 1)
- img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
- return img
-
-
-def random_crop(lq, hq, sf=4, lq_patchsize=64):
- h, w = lq.shape[:2]
- rnd_h = random.randint(0, h - lq_patchsize)
- rnd_w = random.randint(0, w - lq_patchsize)
- lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
-
- rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
- hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
- return lq, hq
-
-
-def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize * sf or w < lq_patchsize * sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- hq = img.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- img = util.imresize_np(img, 1 / 2, True)
- img = np.clip(img, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- img = add_blur(img, sf=sf)
-
- elif i == 1:
- img = add_blur(img, sf=sf)
-
- elif i == 2:
- a, b = img.shape[1], img.shape[0]
- # downsample2
- if random.random() < 0.75:
- sf1 = random.uniform(1, 2 * sf)
- img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
- img = img[0::sf, 0::sf, ...] # nearest downsampling
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- img = add_JPEG_noise(img)
-
- elif i == 6:
- # add processed camera sensor noise
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
-
- return img, hq
-
-
-# todo no isp_model?
-def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- image = util.uint2single(image)
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = image.shape[:2]
- image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = image.shape[:2]
-
- hq = image.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- image = util.imresize_np(image, 1 / 2, True)
- image = np.clip(image, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- image = add_blur(image, sf=sf)
-
- # elif i == 1:
- # image = add_blur(image, sf=sf)
-
- if i == 0:
- pass
-
- elif i == 2:
- a, b = image.shape[1], image.shape[0]
- # downsample2
- if random.random() < 0.8:
- sf1 = random.uniform(1, 2 * sf)
- image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
- image = image[0::sf, 0::sf, ...] # nearest downsampling
-
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- image = add_JPEG_noise(image)
- #
- # elif i == 6:
- # # add processed camera sensor noise
- # if random.random() < isp_prob and isp_model is not None:
- # with torch.no_grad():
- # img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- image = add_JPEG_noise(image)
- image = util.single2uint(image)
- if up:
- image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
- example = {"image": image}
- return example
-
-
-
-
-if __name__ == '__main__':
- print("hey")
- img = util.imread_uint('utils/test.png', 3)
- img = img[:448, :448]
- h = img.shape[0] // 4
- print("resizing to", h)
- sf = 4
- deg_fn = partial(degradation_bsrgan_variant, sf=sf)
- for i in range(20):
- print(i)
- img_hq = img
- img_lq = deg_fn(img)["image"]
- img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
- print(img_lq)
- img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
- print(img_lq.shape)
- print("bicubic", img_lq_bicubic.shape)
- print(img_hq.shape)
- lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
- (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
- util.imsave(img_concat, str(i) + '.png')
diff --git a/comfy/ldm/modules/image_degradation/utils/test.png b/comfy/ldm/modules/image_degradation/utils/test.png
deleted file mode 100644
index 4249b43de..000000000
Binary files a/comfy/ldm/modules/image_degradation/utils/test.png and /dev/null differ
diff --git a/comfy/ldm/modules/image_degradation/utils_image.py b/comfy/ldm/modules/image_degradation/utils_image.py
deleted file mode 100644
index 0175f155a..000000000
--- a/comfy/ldm/modules/image_degradation/utils_image.py
+++ /dev/null
@@ -1,916 +0,0 @@
-import os
-import math
-import random
-import numpy as np
-import torch
-import cv2
-from torchvision.utils import make_grid
-from datetime import datetime
-#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
-
-
-os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
-
-
-'''
-# --------------------------------------------
-# Kai Zhang (github: https://github.com/cszn)
-# 03/Mar/2019
-# --------------------------------------------
-# https://github.com/twhui/SRGAN-pyTorch
-# https://github.com/xinntao/BasicSR
-# --------------------------------------------
-'''
-
-
-IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
-
-
-def is_image_file(filename):
- return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
-
-
-def get_timestamp():
- return datetime.now().strftime('%y%m%d-%H%M%S')
-
-
-def imshow(x, title=None, cbar=False, figsize=None):
- plt.figure(figsize=figsize)
- plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
- if title:
- plt.title(title)
- if cbar:
- plt.colorbar()
- plt.show()
-
-
-def surf(Z, cmap='rainbow', figsize=None):
- plt.figure(figsize=figsize)
- ax3 = plt.axes(projection='3d')
-
- w, h = Z.shape[:2]
- xx = np.arange(0,w,1)
- yy = np.arange(0,h,1)
- X, Y = np.meshgrid(xx, yy)
- ax3.plot_surface(X,Y,Z,cmap=cmap)
- #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
- plt.show()
-
-
-'''
-# --------------------------------------------
-# get image pathes
-# --------------------------------------------
-'''
-
-
-def get_image_paths(dataroot):
- paths = None # return None if dataroot is None
- if dataroot is not None:
- paths = sorted(_get_paths_from_images(dataroot))
- return paths
-
-
-def _get_paths_from_images(path):
- assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
- images = []
- for dirpath, _, fnames in sorted(os.walk(path)):
- for fname in sorted(fnames):
- if is_image_file(fname):
- img_path = os.path.join(dirpath, fname)
- images.append(img_path)
- assert images, '{:s} has no valid image file'.format(path)
- return images
-
-
-'''
-# --------------------------------------------
-# split large images into small images
-# --------------------------------------------
-'''
-
-
-def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
- w, h = img.shape[:2]
- patches = []
- if w > p_max and h > p_max:
- w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
- h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
- w1.append(w-p_size)
- h1.append(h-p_size)
-# print(w1)
-# print(h1)
- for i in w1:
- for j in h1:
- patches.append(img[i:i+p_size, j:j+p_size,:])
- else:
- patches.append(img)
-
- return patches
-
-
-def imssave(imgs, img_path):
- """
- imgs: list, N images of size WxHxC
- """
- img_name, ext = os.path.splitext(os.path.basename(img_path))
-
- for i, img in enumerate(imgs):
- if img.ndim == 3:
- img = img[:, :, [2, 1, 0]]
- new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
- cv2.imwrite(new_path, img)
-
-
-def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
- """
- split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
- and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
- will be splitted.
- Args:
- original_dataroot:
- taget_dataroot:
- p_size: size of small images
- p_overlap: patch size in training is a good choice
- p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
- """
- paths = get_image_paths(original_dataroot)
- for img_path in paths:
- # img_name, ext = os.path.splitext(os.path.basename(img_path))
- img = imread_uint(img_path, n_channels=n_channels)
- patches = patches_from_image(img, p_size, p_overlap, p_max)
- imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
- #if original_dataroot == taget_dataroot:
- #del img_path
-
-'''
-# --------------------------------------------
-# makedir
-# --------------------------------------------
-'''
-
-
-def mkdir(path):
- if not os.path.exists(path):
- os.makedirs(path)
-
-
-def mkdirs(paths):
- if isinstance(paths, str):
- mkdir(paths)
- else:
- for path in paths:
- mkdir(path)
-
-
-def mkdir_and_rename(path):
- if os.path.exists(path):
- new_name = path + '_archived_' + get_timestamp()
- print('Path already exists. Rename it to [{:s}]'.format(new_name))
- os.rename(path, new_name)
- os.makedirs(path)
-
-
-'''
-# --------------------------------------------
-# read image from path
-# opencv is fast, but read BGR numpy image
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# get uint8 image of size HxWxn_channles (RGB)
-# --------------------------------------------
-def imread_uint(path, n_channels=3):
- # input: path
- # output: HxWx3(RGB or GGG), or HxWx1 (G)
- if n_channels == 1:
- img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
- img = np.expand_dims(img, axis=2) # HxWx1
- elif n_channels == 3:
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
- if img.ndim == 2:
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
- else:
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
- return img
-
-
-# --------------------------------------------
-# matlab's imwrite
-# --------------------------------------------
-def imsave(img, img_path):
- img = np.squeeze(img)
- if img.ndim == 3:
- img = img[:, :, [2, 1, 0]]
- cv2.imwrite(img_path, img)
-
-def imwrite(img, img_path):
- img = np.squeeze(img)
- if img.ndim == 3:
- img = img[:, :, [2, 1, 0]]
- cv2.imwrite(img_path, img)
-
-
-
-# --------------------------------------------
-# get single image of size HxWxn_channles (BGR)
-# --------------------------------------------
-def read_img(path):
- # read image by cv2
- # return: Numpy float32, HWC, BGR, [0,1]
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
- img = img.astype(np.float32) / 255.
- if img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- # some images have 4 channels
- if img.shape[2] > 3:
- img = img[:, :, :3]
- return img
-
-
-'''
-# --------------------------------------------
-# image format conversion
-# --------------------------------------------
-# numpy(single) <---> numpy(unit)
-# numpy(single) <---> tensor
-# numpy(unit) <---> tensor
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# numpy(single) [0, 1] <---> numpy(unit)
-# --------------------------------------------
-
-
-def uint2single(img):
-
- return np.float32(img/255.)
-
-
-def single2uint(img):
-
- return np.uint8((img.clip(0, 1)*255.).round())
-
-
-def uint162single(img):
-
- return np.float32(img/65535.)
-
-
-def single2uint16(img):
-
- return np.uint16((img.clip(0, 1)*65535.).round())
-
-
-# --------------------------------------------
-# numpy(unit) (HxWxC or HxW) <---> tensor
-# --------------------------------------------
-
-
-# convert uint to 4-dimensional torch tensor
-def uint2tensor4(img):
- if img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
-
-
-# convert uint to 3-dimensional torch tensor
-def uint2tensor3(img):
- if img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
-
-
-# convert 2/3/4-dimensional torch tensor to uint
-def tensor2uint(img):
- img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
- if img.ndim == 3:
- img = np.transpose(img, (1, 2, 0))
- return np.uint8((img*255.0).round())
-
-
-# --------------------------------------------
-# numpy(single) (HxWxC) <---> tensor
-# --------------------------------------------
-
-
-# convert single (HxWxC) to 3-dimensional torch tensor
-def single2tensor3(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
-
-
-# convert single (HxWxC) to 4-dimensional torch tensor
-def single2tensor4(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
-
-
-# convert torch tensor to single
-def tensor2single(img):
- img = img.data.squeeze().float().cpu().numpy()
- if img.ndim == 3:
- img = np.transpose(img, (1, 2, 0))
-
- return img
-
-# convert torch tensor to single
-def tensor2single3(img):
- img = img.data.squeeze().float().cpu().numpy()
- if img.ndim == 3:
- img = np.transpose(img, (1, 2, 0))
- elif img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- return img
-
-
-def single2tensor5(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
-
-
-def single32tensor5(img):
- return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
-
-
-def single42tensor4(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
-
-
-# from skimage.io import imread, imsave
-def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
- '''
- Converts a torch Tensor into an image Numpy array of BGR channel order
- Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
- Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
- '''
- tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
- tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
- n_dim = tensor.dim()
- if n_dim == 4:
- n_img = len(tensor)
- img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
- img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
- elif n_dim == 3:
- img_np = tensor.numpy()
- img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
- elif n_dim == 2:
- img_np = tensor.numpy()
- else:
- raise TypeError(
- 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
- if out_type == np.uint8:
- img_np = (img_np * 255.0).round()
- # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
- return img_np.astype(out_type)
-
-
-'''
-# --------------------------------------------
-# Augmentation, flipe and/or rotate
-# --------------------------------------------
-# The following two are enough.
-# (1) augmet_img: numpy image of WxHxC or WxH
-# (2) augment_img_tensor4: tensor image 1xCxWxH
-# --------------------------------------------
-'''
-
-
-def augment_img(img, mode=0):
- '''Kai Zhang (github: https://github.com/cszn)
- '''
- if mode == 0:
- return img
- elif mode == 1:
- return np.flipud(np.rot90(img))
- elif mode == 2:
- return np.flipud(img)
- elif mode == 3:
- return np.rot90(img, k=3)
- elif mode == 4:
- return np.flipud(np.rot90(img, k=2))
- elif mode == 5:
- return np.rot90(img)
- elif mode == 6:
- return np.rot90(img, k=2)
- elif mode == 7:
- return np.flipud(np.rot90(img, k=3))
-
-
-def augment_img_tensor4(img, mode=0):
- '''Kai Zhang (github: https://github.com/cszn)
- '''
- if mode == 0:
- return img
- elif mode == 1:
- return img.rot90(1, [2, 3]).flip([2])
- elif mode == 2:
- return img.flip([2])
- elif mode == 3:
- return img.rot90(3, [2, 3])
- elif mode == 4:
- return img.rot90(2, [2, 3]).flip([2])
- elif mode == 5:
- return img.rot90(1, [2, 3])
- elif mode == 6:
- return img.rot90(2, [2, 3])
- elif mode == 7:
- return img.rot90(3, [2, 3]).flip([2])
-
-
-def augment_img_tensor(img, mode=0):
- '''Kai Zhang (github: https://github.com/cszn)
- '''
- img_size = img.size()
- img_np = img.data.cpu().numpy()
- if len(img_size) == 3:
- img_np = np.transpose(img_np, (1, 2, 0))
- elif len(img_size) == 4:
- img_np = np.transpose(img_np, (2, 3, 1, 0))
- img_np = augment_img(img_np, mode=mode)
- img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
- if len(img_size) == 3:
- img_tensor = img_tensor.permute(2, 0, 1)
- elif len(img_size) == 4:
- img_tensor = img_tensor.permute(3, 2, 0, 1)
-
- return img_tensor.type_as(img)
-
-
-def augment_img_np3(img, mode=0):
- if mode == 0:
- return img
- elif mode == 1:
- return img.transpose(1, 0, 2)
- elif mode == 2:
- return img[::-1, :, :]
- elif mode == 3:
- img = img[::-1, :, :]
- img = img.transpose(1, 0, 2)
- return img
- elif mode == 4:
- return img[:, ::-1, :]
- elif mode == 5:
- img = img[:, ::-1, :]
- img = img.transpose(1, 0, 2)
- return img
- elif mode == 6:
- img = img[:, ::-1, :]
- img = img[::-1, :, :]
- return img
- elif mode == 7:
- img = img[:, ::-1, :]
- img = img[::-1, :, :]
- img = img.transpose(1, 0, 2)
- return img
-
-
-def augment_imgs(img_list, hflip=True, rot=True):
- # horizontal flip OR rotate
- hflip = hflip and random.random() < 0.5
- vflip = rot and random.random() < 0.5
- rot90 = rot and random.random() < 0.5
-
- def _augment(img):
- if hflip:
- img = img[:, ::-1, :]
- if vflip:
- img = img[::-1, :, :]
- if rot90:
- img = img.transpose(1, 0, 2)
- return img
-
- return [_augment(img) for img in img_list]
-
-
-'''
-# --------------------------------------------
-# modcrop and shave
-# --------------------------------------------
-'''
-
-
-def modcrop(img_in, scale):
- # img_in: Numpy, HWC or HW
- img = np.copy(img_in)
- if img.ndim == 2:
- H, W = img.shape
- H_r, W_r = H % scale, W % scale
- img = img[:H - H_r, :W - W_r]
- elif img.ndim == 3:
- H, W, C = img.shape
- H_r, W_r = H % scale, W % scale
- img = img[:H - H_r, :W - W_r, :]
- else:
- raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
- return img
-
-
-def shave(img_in, border=0):
- # img_in: Numpy, HWC or HW
- img = np.copy(img_in)
- h, w = img.shape[:2]
- img = img[border:h-border, border:w-border]
- return img
-
-
-'''
-# --------------------------------------------
-# image processing process on numpy image
-# channel_convert(in_c, tar_type, img_list):
-# rgb2ycbcr(img, only_y=True):
-# bgr2ycbcr(img, only_y=True):
-# ycbcr2rgb(img):
-# --------------------------------------------
-'''
-
-
-def rgb2ycbcr(img, only_y=True):
- '''same as matlab rgb2ycbcr
- only_y: only return Y channel
- Input:
- uint8, [0, 255]
- float, [0, 1]
- '''
- in_img_type = img.dtype
- img.astype(np.float32)
- if in_img_type != np.uint8:
- img *= 255.
- # convert
- if only_y:
- rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
- else:
- rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
- [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
- if in_img_type == np.uint8:
- rlt = rlt.round()
- else:
- rlt /= 255.
- return rlt.astype(in_img_type)
-
-
-def ycbcr2rgb(img):
- '''same as matlab ycbcr2rgb
- Input:
- uint8, [0, 255]
- float, [0, 1]
- '''
- in_img_type = img.dtype
- img.astype(np.float32)
- if in_img_type != np.uint8:
- img *= 255.
- # convert
- rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
- [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
- if in_img_type == np.uint8:
- rlt = rlt.round()
- else:
- rlt /= 255.
- return rlt.astype(in_img_type)
-
-
-def bgr2ycbcr(img, only_y=True):
- '''bgr version of rgb2ycbcr
- only_y: only return Y channel
- Input:
- uint8, [0, 255]
- float, [0, 1]
- '''
- in_img_type = img.dtype
- img.astype(np.float32)
- if in_img_type != np.uint8:
- img *= 255.
- # convert
- if only_y:
- rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
- else:
- rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
- [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
- if in_img_type == np.uint8:
- rlt = rlt.round()
- else:
- rlt /= 255.
- return rlt.astype(in_img_type)
-
-
-def channel_convert(in_c, tar_type, img_list):
- # conversion among BGR, gray and y
- if in_c == 3 and tar_type == 'gray': # BGR to gray
- gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
- return [np.expand_dims(img, axis=2) for img in gray_list]
- elif in_c == 3 and tar_type == 'y': # BGR to y
- y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
- return [np.expand_dims(img, axis=2) for img in y_list]
- elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
- return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
- else:
- return img_list
-
-
-'''
-# --------------------------------------------
-# metric, PSNR and SSIM
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# PSNR
-# --------------------------------------------
-def calculate_psnr(img1, img2, border=0):
- # img1 and img2 have range [0, 255]
- #img1 = img1.squeeze()
- #img2 = img2.squeeze()
- if not img1.shape == img2.shape:
- raise ValueError('Input images must have the same dimensions.')
- h, w = img1.shape[:2]
- img1 = img1[border:h-border, border:w-border]
- img2 = img2[border:h-border, border:w-border]
-
- img1 = img1.astype(np.float64)
- img2 = img2.astype(np.float64)
- mse = np.mean((img1 - img2)**2)
- if mse == 0:
- return float('inf')
- return 20 * math.log10(255.0 / math.sqrt(mse))
-
-
-# --------------------------------------------
-# SSIM
-# --------------------------------------------
-def calculate_ssim(img1, img2, border=0):
- '''calculate SSIM
- the same outputs as MATLAB's
- img1, img2: [0, 255]
- '''
- #img1 = img1.squeeze()
- #img2 = img2.squeeze()
- if not img1.shape == img2.shape:
- raise ValueError('Input images must have the same dimensions.')
- h, w = img1.shape[:2]
- img1 = img1[border:h-border, border:w-border]
- img2 = img2[border:h-border, border:w-border]
-
- if img1.ndim == 2:
- return ssim(img1, img2)
- elif img1.ndim == 3:
- if img1.shape[2] == 3:
- ssims = []
- for i in range(3):
- ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
- return np.array(ssims).mean()
- elif img1.shape[2] == 1:
- return ssim(np.squeeze(img1), np.squeeze(img2))
- else:
- raise ValueError('Wrong input image dimensions.')
-
-
-def ssim(img1, img2):
- C1 = (0.01 * 255)**2
- C2 = (0.03 * 255)**2
-
- img1 = img1.astype(np.float64)
- img2 = img2.astype(np.float64)
- kernel = cv2.getGaussianKernel(11, 1.5)
- window = np.outer(kernel, kernel.transpose())
-
- mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
- mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
- mu1_sq = mu1**2
- mu2_sq = mu2**2
- mu1_mu2 = mu1 * mu2
- sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
- sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
- sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
-
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
- (sigma1_sq + sigma2_sq + C2))
- return ssim_map.mean()
-
-
-'''
-# --------------------------------------------
-# matlab's bicubic imresize (numpy and torch) [0, 1]
-# --------------------------------------------
-'''
-
-
-# matlab 'imresize' function, now only support 'bicubic'
-def cubic(x):
- absx = torch.abs(x)
- absx2 = absx**2
- absx3 = absx**3
- return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
- (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
-
-
-def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
- if (scale < 1) and (antialiasing):
- # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
- kernel_width = kernel_width / scale
-
- # Output-space coordinates
- x = torch.linspace(1, out_length, out_length)
-
- # Input-space coordinates. Calculate the inverse mapping such that 0.5
- # in output space maps to 0.5 in input space, and 0.5+scale in output
- # space maps to 1.5 in input space.
- u = x / scale + 0.5 * (1 - 1 / scale)
-
- # What is the left-most pixel that can be involved in the computation?
- left = torch.floor(u - kernel_width / 2)
-
- # What is the maximum number of pixels that can be involved in the
- # computation? Note: it's OK to use an extra pixel here; if the
- # corresponding weights are all zero, it will be eliminated at the end
- # of this function.
- P = math.ceil(kernel_width) + 2
-
- # The indices of the input pixels involved in computing the k-th output
- # pixel are in row k of the indices matrix.
- indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
- 1, P).expand(out_length, P)
-
- # The weights used to compute the k-th output pixel are in row k of the
- # weights matrix.
- distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
- # apply cubic kernel
- if (scale < 1) and (antialiasing):
- weights = scale * cubic(distance_to_center * scale)
- else:
- weights = cubic(distance_to_center)
- # Normalize the weights matrix so that each row sums to 1.
- weights_sum = torch.sum(weights, 1).view(out_length, 1)
- weights = weights / weights_sum.expand(out_length, P)
-
- # If a column in weights is all zero, get rid of it. only consider the first and last column.
- weights_zero_tmp = torch.sum((weights == 0), 0)
- if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
- indices = indices.narrow(1, 1, P - 2)
- weights = weights.narrow(1, 1, P - 2)
- if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
- indices = indices.narrow(1, 0, P - 2)
- weights = weights.narrow(1, 0, P - 2)
- weights = weights.contiguous()
- indices = indices.contiguous()
- sym_len_s = -indices.min() + 1
- sym_len_e = indices.max() - in_length
- indices = indices + sym_len_s - 1
- return weights, indices, int(sym_len_s), int(sym_len_e)
-
-
-# --------------------------------------------
-# imresize for tensor image [0, 1]
-# --------------------------------------------
-def imresize(img, scale, antialiasing=True):
- # Now the scale should be the same for H and W
- # input: img: pytorch tensor, CHW or HW [0,1]
- # output: CHW or HW [0,1] w/o round
- need_squeeze = True if img.dim() == 2 else False
- if need_squeeze:
- img.unsqueeze_(0)
- in_C, in_H, in_W = img.size()
- out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
- kernel_width = 4
- kernel = 'cubic'
-
- # Return the desired dimension order for performing the resize. The
- # strategy is to perform the resize first along the dimension with the
- # smallest scale factor.
- # Now we do not support this.
-
- # get weights and indices
- weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
- in_H, out_H, scale, kernel, kernel_width, antialiasing)
- weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
- in_W, out_W, scale, kernel, kernel_width, antialiasing)
- # process H dimension
- # symmetric copying
- img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
- img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
-
- sym_patch = img[:, :sym_len_Hs, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
-
- sym_patch = img[:, -sym_len_He:, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
-
- out_1 = torch.FloatTensor(in_C, out_H, in_W)
- kernel_width = weights_H.size(1)
- for i in range(out_H):
- idx = int(indices_H[i][0])
- for j in range(out_C):
- out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
-
- # process W dimension
- # symmetric copying
- out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
- out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
-
- sym_patch = out_1[:, :, :sym_len_Ws]
- inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(2, inv_idx)
- out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
-
- sym_patch = out_1[:, :, -sym_len_We:]
- inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(2, inv_idx)
- out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
-
- out_2 = torch.FloatTensor(in_C, out_H, out_W)
- kernel_width = weights_W.size(1)
- for i in range(out_W):
- idx = int(indices_W[i][0])
- for j in range(out_C):
- out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
- if need_squeeze:
- out_2.squeeze_()
- return out_2
-
-
-# --------------------------------------------
-# imresize for numpy image [0, 1]
-# --------------------------------------------
-def imresize_np(img, scale, antialiasing=True):
- # Now the scale should be the same for H and W
- # input: img: Numpy, HWC or HW [0,1]
- # output: HWC or HW [0,1] w/o round
- img = torch.from_numpy(img)
- need_squeeze = True if img.dim() == 2 else False
- if need_squeeze:
- img.unsqueeze_(2)
-
- in_H, in_W, in_C = img.size()
- out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
- kernel_width = 4
- kernel = 'cubic'
-
- # Return the desired dimension order for performing the resize. The
- # strategy is to perform the resize first along the dimension with the
- # smallest scale factor.
- # Now we do not support this.
-
- # get weights and indices
- weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
- in_H, out_H, scale, kernel, kernel_width, antialiasing)
- weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
- in_W, out_W, scale, kernel, kernel_width, antialiasing)
- # process H dimension
- # symmetric copying
- img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
- img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
-
- sym_patch = img[:sym_len_Hs, :, :]
- inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(0, inv_idx)
- img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
-
- sym_patch = img[-sym_len_He:, :, :]
- inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(0, inv_idx)
- img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
-
- out_1 = torch.FloatTensor(out_H, in_W, in_C)
- kernel_width = weights_H.size(1)
- for i in range(out_H):
- idx = int(indices_H[i][0])
- for j in range(out_C):
- out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
-
- # process W dimension
- # symmetric copying
- out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
- out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
-
- sym_patch = out_1[:, :sym_len_Ws, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
-
- sym_patch = out_1[:, -sym_len_We:, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
-
- out_2 = torch.FloatTensor(out_H, out_W, in_C)
- kernel_width = weights_W.size(1)
- for i in range(out_W):
- idx = int(indices_W[i][0])
- for j in range(out_C):
- out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
- if need_squeeze:
- out_2.squeeze_()
-
- return out_2.numpy()
-
-
-if __name__ == '__main__':
- print('---')
-# img = imread_uint('test.bmp', 3)
-# img = uint2single(img)
-# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/comfy/ldm/modules/midas/__init__.py b/comfy/ldm/modules/midas/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/comfy/ldm/modules/midas/api.py b/comfy/ldm/modules/midas/api.py
deleted file mode 100644
index b58ebbffd..000000000
--- a/comfy/ldm/modules/midas/api.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# based on https://github.com/isl-org/MiDaS
-
-import cv2
-import torch
-import torch.nn as nn
-from torchvision.transforms import Compose
-
-from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
-from ldm.modules.midas.midas.midas_net import MidasNet
-from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
-from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
-
-
-ISL_PATHS = {
- "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
- "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
- "midas_v21": "",
- "midas_v21_small": "",
-}
-
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-
-def load_midas_transform(model_type):
- # https://github.com/isl-org/MiDaS/blob/master/run.py
- # load transform only
- if model_type == "dpt_large": # DPT-Large
- net_w, net_h = 384, 384
- resize_mode = "minimal"
- normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-
- elif model_type == "dpt_hybrid": # DPT-Hybrid
- net_w, net_h = 384, 384
- resize_mode = "minimal"
- normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-
- elif model_type == "midas_v21":
- net_w, net_h = 384, 384
- resize_mode = "upper_bound"
- normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-
- elif model_type == "midas_v21_small":
- net_w, net_h = 256, 256
- resize_mode = "upper_bound"
- normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-
- else:
- assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
-
- transform = Compose(
- [
- Resize(
- net_w,
- net_h,
- resize_target=None,
- keep_aspect_ratio=True,
- ensure_multiple_of=32,
- resize_method=resize_mode,
- image_interpolation_method=cv2.INTER_CUBIC,
- ),
- normalization,
- PrepareForNet(),
- ]
- )
-
- return transform
-
-
-def load_model(model_type):
- # https://github.com/isl-org/MiDaS/blob/master/run.py
- # load network
- model_path = ISL_PATHS[model_type]
- if model_type == "dpt_large": # DPT-Large
- model = DPTDepthModel(
- path=model_path,
- backbone="vitl16_384",
- non_negative=True,
- )
- net_w, net_h = 384, 384
- resize_mode = "minimal"
- normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-
- elif model_type == "dpt_hybrid": # DPT-Hybrid
- model = DPTDepthModel(
- path=model_path,
- backbone="vitb_rn50_384",
- non_negative=True,
- )
- net_w, net_h = 384, 384
- resize_mode = "minimal"
- normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-
- elif model_type == "midas_v21":
- model = MidasNet(model_path, non_negative=True)
- net_w, net_h = 384, 384
- resize_mode = "upper_bound"
- normalization = NormalizeImage(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
- )
-
- elif model_type == "midas_v21_small":
- model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
- non_negative=True, blocks={'expand': True})
- net_w, net_h = 256, 256
- resize_mode = "upper_bound"
- normalization = NormalizeImage(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
- )
-
- else:
- print(f"model_type '{model_type}' not implemented, use: --model_type large")
- assert False
-
- transform = Compose(
- [
- Resize(
- net_w,
- net_h,
- resize_target=None,
- keep_aspect_ratio=True,
- ensure_multiple_of=32,
- resize_method=resize_mode,
- image_interpolation_method=cv2.INTER_CUBIC,
- ),
- normalization,
- PrepareForNet(),
- ]
- )
-
- return model.eval(), transform
-
-
-class MiDaSInference(nn.Module):
- MODEL_TYPES_TORCH_HUB = [
- "DPT_Large",
- "DPT_Hybrid",
- "MiDaS_small"
- ]
- MODEL_TYPES_ISL = [
- "dpt_large",
- "dpt_hybrid",
- "midas_v21",
- "midas_v21_small",
- ]
-
- def __init__(self, model_type):
- super().__init__()
- assert (model_type in self.MODEL_TYPES_ISL)
- model, _ = load_model(model_type)
- self.model = model
- self.model.train = disabled_train
-
- def forward(self, x):
- # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
- # NOTE: we expect that the correct transform has been called during dataloading.
- with torch.no_grad():
- prediction = self.model(x)
- prediction = torch.nn.functional.interpolate(
- prediction.unsqueeze(1),
- size=x.shape[2:],
- mode="bicubic",
- align_corners=False,
- )
- assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
- return prediction
-
diff --git a/comfy/ldm/modules/midas/midas/__init__.py b/comfy/ldm/modules/midas/midas/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/comfy/ldm/modules/midas/midas/base_model.py b/comfy/ldm/modules/midas/midas/base_model.py
deleted file mode 100644
index 5cf430239..000000000
--- a/comfy/ldm/modules/midas/midas/base_model.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import torch
-
-
-class BaseModel(torch.nn.Module):
- def load(self, path):
- """Load model from file.
-
- Args:
- path (str): file path
- """
- parameters = torch.load(path, map_location=torch.device('cpu'))
-
- if "optimizer" in parameters:
- parameters = parameters["model"]
-
- self.load_state_dict(parameters)
diff --git a/comfy/ldm/modules/midas/midas/blocks.py b/comfy/ldm/modules/midas/midas/blocks.py
deleted file mode 100644
index 2145d18fa..000000000
--- a/comfy/ldm/modules/midas/midas/blocks.py
+++ /dev/null
@@ -1,342 +0,0 @@
-import torch
-import torch.nn as nn
-
-from .vit import (
- _make_pretrained_vitb_rn50_384,
- _make_pretrained_vitl16_384,
- _make_pretrained_vitb16_384,
- forward_vit,
-)
-
-def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
- if backbone == "vitl16_384":
- pretrained = _make_pretrained_vitl16_384(
- use_pretrained, hooks=hooks, use_readout=use_readout
- )
- scratch = _make_scratch(
- [256, 512, 1024, 1024], features, groups=groups, expand=expand
- ) # ViT-L/16 - 85.0% Top1 (backbone)
- elif backbone == "vitb_rn50_384":
- pretrained = _make_pretrained_vitb_rn50_384(
- use_pretrained,
- hooks=hooks,
- use_vit_only=use_vit_only,
- use_readout=use_readout,
- )
- scratch = _make_scratch(
- [256, 512, 768, 768], features, groups=groups, expand=expand
- ) # ViT-H/16 - 85.0% Top1 (backbone)
- elif backbone == "vitb16_384":
- pretrained = _make_pretrained_vitb16_384(
- use_pretrained, hooks=hooks, use_readout=use_readout
- )
- scratch = _make_scratch(
- [96, 192, 384, 768], features, groups=groups, expand=expand
- ) # ViT-B/16 - 84.6% Top1 (backbone)
- elif backbone == "resnext101_wsl":
- pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
- scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
- elif backbone == "efficientnet_lite3":
- pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
- scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
- else:
- print(f"Backbone '{backbone}' not implemented")
- assert False
-
- return pretrained, scratch
-
-
-def _make_scratch(in_shape, out_shape, groups=1, expand=False):
- scratch = nn.Module()
-
- out_shape1 = out_shape
- out_shape2 = out_shape
- out_shape3 = out_shape
- out_shape4 = out_shape
- if expand==True:
- out_shape1 = out_shape
- out_shape2 = out_shape*2
- out_shape3 = out_shape*4
- out_shape4 = out_shape*8
-
- scratch.layer1_rn = nn.Conv2d(
- in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
- )
- scratch.layer2_rn = nn.Conv2d(
- in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
- )
- scratch.layer3_rn = nn.Conv2d(
- in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
- )
- scratch.layer4_rn = nn.Conv2d(
- in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
- )
-
- return scratch
-
-
-def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
- efficientnet = torch.hub.load(
- "rwightman/gen-efficientnet-pytorch",
- "tf_efficientnet_lite3",
- pretrained=use_pretrained,
- exportable=exportable
- )
- return _make_efficientnet_backbone(efficientnet)
-
-
-def _make_efficientnet_backbone(effnet):
- pretrained = nn.Module()
-
- pretrained.layer1 = nn.Sequential(
- effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
- )
- pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
- pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
- pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
-
- return pretrained
-
-
-def _make_resnet_backbone(resnet):
- pretrained = nn.Module()
- pretrained.layer1 = nn.Sequential(
- resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
- )
-
- pretrained.layer2 = resnet.layer2
- pretrained.layer3 = resnet.layer3
- pretrained.layer4 = resnet.layer4
-
- return pretrained
-
-
-def _make_pretrained_resnext101_wsl(use_pretrained):
- resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
- return _make_resnet_backbone(resnet)
-
-
-
-class Interpolate(nn.Module):
- """Interpolation module.
- """
-
- def __init__(self, scale_factor, mode, align_corners=False):
- """Init.
-
- Args:
- scale_factor (float): scaling
- mode (str): interpolation mode
- """
- super(Interpolate, self).__init__()
-
- self.interp = nn.functional.interpolate
- self.scale_factor = scale_factor
- self.mode = mode
- self.align_corners = align_corners
-
- def forward(self, x):
- """Forward pass.
-
- Args:
- x (tensor): input
-
- Returns:
- tensor: interpolated data
- """
-
- x = self.interp(
- x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
- )
-
- return x
-
-
-class ResidualConvUnit(nn.Module):
- """Residual convolution module.
- """
-
- def __init__(self, features):
- """Init.
-
- Args:
- features (int): number of features
- """
- super().__init__()
-
- self.conv1 = nn.Conv2d(
- features, features, kernel_size=3, stride=1, padding=1, bias=True
- )
-
- self.conv2 = nn.Conv2d(
- features, features, kernel_size=3, stride=1, padding=1, bias=True
- )
-
- self.relu = nn.ReLU(inplace=True)
-
- def forward(self, x):
- """Forward pass.
-
- Args:
- x (tensor): input
-
- Returns:
- tensor: output
- """
- out = self.relu(x)
- out = self.conv1(out)
- out = self.relu(out)
- out = self.conv2(out)
-
- return out + x
-
-
-class FeatureFusionBlock(nn.Module):
- """Feature fusion block.
- """
-
- def __init__(self, features):
- """Init.
-
- Args:
- features (int): number of features
- """
- super(FeatureFusionBlock, self).__init__()
-
- self.resConfUnit1 = ResidualConvUnit(features)
- self.resConfUnit2 = ResidualConvUnit(features)
-
- def forward(self, *xs):
- """Forward pass.
-
- Returns:
- tensor: output
- """
- output = xs[0]
-
- if len(xs) == 2:
- output += self.resConfUnit1(xs[1])
-
- output = self.resConfUnit2(output)
-
- output = nn.functional.interpolate(
- output, scale_factor=2, mode="bilinear", align_corners=True
- )
-
- return output
-
-
-
-
-class ResidualConvUnit_custom(nn.Module):
- """Residual convolution module.
- """
-
- def __init__(self, features, activation, bn):
- """Init.
-
- Args:
- features (int): number of features
- """
- super().__init__()
-
- self.bn = bn
-
- self.groups=1
-
- self.conv1 = nn.Conv2d(
- features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
- )
-
- self.conv2 = nn.Conv2d(
- features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
- )
-
- if self.bn==True:
- self.bn1 = nn.BatchNorm2d(features)
- self.bn2 = nn.BatchNorm2d(features)
-
- self.activation = activation
-
- self.skip_add = nn.quantized.FloatFunctional()
-
- def forward(self, x):
- """Forward pass.
-
- Args:
- x (tensor): input
-
- Returns:
- tensor: output
- """
-
- out = self.activation(x)
- out = self.conv1(out)
- if self.bn==True:
- out = self.bn1(out)
-
- out = self.activation(out)
- out = self.conv2(out)
- if self.bn==True:
- out = self.bn2(out)
-
- if self.groups > 1:
- out = self.conv_merge(out)
-
- return self.skip_add.add(out, x)
-
- # return out + x
-
-
-class FeatureFusionBlock_custom(nn.Module):
- """Feature fusion block.
- """
-
- def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
- """Init.
-
- Args:
- features (int): number of features
- """
- super(FeatureFusionBlock_custom, self).__init__()
-
- self.deconv = deconv
- self.align_corners = align_corners
-
- self.groups=1
-
- self.expand = expand
- out_features = features
- if self.expand==True:
- out_features = features//2
-
- self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
-
- self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
- self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
-
- self.skip_add = nn.quantized.FloatFunctional()
-
- def forward(self, *xs):
- """Forward pass.
-
- Returns:
- tensor: output
- """
- output = xs[0]
-
- if len(xs) == 2:
- res = self.resConfUnit1(xs[1])
- output = self.skip_add.add(output, res)
- # output += res
-
- output = self.resConfUnit2(output)
-
- output = nn.functional.interpolate(
- output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
- )
-
- output = self.out_conv(output)
-
- return output
-
diff --git a/comfy/ldm/modules/midas/midas/dpt_depth.py b/comfy/ldm/modules/midas/midas/dpt_depth.py
deleted file mode 100644
index 4e9aab5d2..000000000
--- a/comfy/ldm/modules/midas/midas/dpt_depth.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .base_model import BaseModel
-from .blocks import (
- FeatureFusionBlock,
- FeatureFusionBlock_custom,
- Interpolate,
- _make_encoder,
- forward_vit,
-)
-
-
-def _make_fusion_block(features, use_bn):
- return FeatureFusionBlock_custom(
- features,
- nn.ReLU(False),
- deconv=False,
- bn=use_bn,
- expand=False,
- align_corners=True,
- )
-
-
-class DPT(BaseModel):
- def __init__(
- self,
- head,
- features=256,
- backbone="vitb_rn50_384",
- readout="project",
- channels_last=False,
- use_bn=False,
- ):
-
- super(DPT, self).__init__()
-
- self.channels_last = channels_last
-
- hooks = {
- "vitb_rn50_384": [0, 1, 8, 11],
- "vitb16_384": [2, 5, 8, 11],
- "vitl16_384": [5, 11, 17, 23],
- }
-
- # Instantiate backbone and reassemble blocks
- self.pretrained, self.scratch = _make_encoder(
- backbone,
- features,
- False, # Set to true of you want to train from scratch, uses ImageNet weights
- groups=1,
- expand=False,
- exportable=False,
- hooks=hooks[backbone],
- use_readout=readout,
- )
-
- self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
- self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
- self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
- self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
-
- self.scratch.output_conv = head
-
-
- def forward(self, x):
- if self.channels_last == True:
- x.contiguous(memory_format=torch.channels_last)
-
- layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
-
- layer_1_rn = self.scratch.layer1_rn(layer_1)
- layer_2_rn = self.scratch.layer2_rn(layer_2)
- layer_3_rn = self.scratch.layer3_rn(layer_3)
- layer_4_rn = self.scratch.layer4_rn(layer_4)
-
- path_4 = self.scratch.refinenet4(layer_4_rn)
- path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
- path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
- path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
-
- out = self.scratch.output_conv(path_1)
-
- return out
-
-
-class DPTDepthModel(DPT):
- def __init__(self, path=None, non_negative=True, **kwargs):
- features = kwargs["features"] if "features" in kwargs else 256
-
- head = nn.Sequential(
- nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
- Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
- nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
- nn.ReLU(True),
- nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
- nn.ReLU(True) if non_negative else nn.Identity(),
- nn.Identity(),
- )
-
- super().__init__(head, **kwargs)
-
- if path is not None:
- self.load(path)
-
- def forward(self, x):
- return super().forward(x).squeeze(dim=1)
-
diff --git a/comfy/ldm/modules/midas/midas/midas_net.py b/comfy/ldm/modules/midas/midas/midas_net.py
deleted file mode 100644
index 8a9549778..000000000
--- a/comfy/ldm/modules/midas/midas/midas_net.py
+++ /dev/null
@@ -1,76 +0,0 @@
-"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
-This file contains code that is adapted from
-https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
-"""
-import torch
-import torch.nn as nn
-
-from .base_model import BaseModel
-from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
-
-
-class MidasNet(BaseModel):
- """Network for monocular depth estimation.
- """
-
- def __init__(self, path=None, features=256, non_negative=True):
- """Init.
-
- Args:
- path (str, optional): Path to saved model. Defaults to None.
- features (int, optional): Number of features. Defaults to 256.
- backbone (str, optional): Backbone network for encoder. Defaults to resnet50
- """
- print("Loading weights: ", path)
-
- super(MidasNet, self).__init__()
-
- use_pretrained = False if path is None else True
-
- self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
-
- self.scratch.refinenet4 = FeatureFusionBlock(features)
- self.scratch.refinenet3 = FeatureFusionBlock(features)
- self.scratch.refinenet2 = FeatureFusionBlock(features)
- self.scratch.refinenet1 = FeatureFusionBlock(features)
-
- self.scratch.output_conv = nn.Sequential(
- nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
- Interpolate(scale_factor=2, mode="bilinear"),
- nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
- nn.ReLU(True),
- nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
- nn.ReLU(True) if non_negative else nn.Identity(),
- )
-
- if path:
- self.load(path)
-
- def forward(self, x):
- """Forward pass.
-
- Args:
- x (tensor): input data (image)
-
- Returns:
- tensor: depth
- """
-
- layer_1 = self.pretrained.layer1(x)
- layer_2 = self.pretrained.layer2(layer_1)
- layer_3 = self.pretrained.layer3(layer_2)
- layer_4 = self.pretrained.layer4(layer_3)
-
- layer_1_rn = self.scratch.layer1_rn(layer_1)
- layer_2_rn = self.scratch.layer2_rn(layer_2)
- layer_3_rn = self.scratch.layer3_rn(layer_3)
- layer_4_rn = self.scratch.layer4_rn(layer_4)
-
- path_4 = self.scratch.refinenet4(layer_4_rn)
- path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
- path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
- path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
-
- out = self.scratch.output_conv(path_1)
-
- return torch.squeeze(out, dim=1)
diff --git a/comfy/ldm/modules/midas/midas/midas_net_custom.py b/comfy/ldm/modules/midas/midas/midas_net_custom.py
deleted file mode 100644
index 50e4acb5e..000000000
--- a/comfy/ldm/modules/midas/midas/midas_net_custom.py
+++ /dev/null
@@ -1,128 +0,0 @@
-"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
-This file contains code that is adapted from
-https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
-"""
-import torch
-import torch.nn as nn
-
-from .base_model import BaseModel
-from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
-
-
-class MidasNet_small(BaseModel):
- """Network for monocular depth estimation.
- """
-
- def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
- blocks={'expand': True}):
- """Init.
-
- Args:
- path (str, optional): Path to saved model. Defaults to None.
- features (int, optional): Number of features. Defaults to 256.
- backbone (str, optional): Backbone network for encoder. Defaults to resnet50
- """
- print("Loading weights: ", path)
-
- super(MidasNet_small, self).__init__()
-
- use_pretrained = False if path else True
-
- self.channels_last = channels_last
- self.blocks = blocks
- self.backbone = backbone
-
- self.groups = 1
-
- features1=features
- features2=features
- features3=features
- features4=features
- self.expand = False
- if "expand" in self.blocks and self.blocks['expand'] == True:
- self.expand = True
- features1=features
- features2=features*2
- features3=features*4
- features4=features*8
-
- self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
-
- self.scratch.activation = nn.ReLU(False)
-
- self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
- self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
- self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
- self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
-
-
- self.scratch.output_conv = nn.Sequential(
- nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
- Interpolate(scale_factor=2, mode="bilinear"),
- nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
- self.scratch.activation,
- nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
- nn.ReLU(True) if non_negative else nn.Identity(),
- nn.Identity(),
- )
-
- if path:
- self.load(path)
-
-
- def forward(self, x):
- """Forward pass.
-
- Args:
- x (tensor): input data (image)
-
- Returns:
- tensor: depth
- """
- if self.channels_last==True:
- print("self.channels_last = ", self.channels_last)
- x.contiguous(memory_format=torch.channels_last)
-
-
- layer_1 = self.pretrained.layer1(x)
- layer_2 = self.pretrained.layer2(layer_1)
- layer_3 = self.pretrained.layer3(layer_2)
- layer_4 = self.pretrained.layer4(layer_3)
-
- layer_1_rn = self.scratch.layer1_rn(layer_1)
- layer_2_rn = self.scratch.layer2_rn(layer_2)
- layer_3_rn = self.scratch.layer3_rn(layer_3)
- layer_4_rn = self.scratch.layer4_rn(layer_4)
-
-
- path_4 = self.scratch.refinenet4(layer_4_rn)
- path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
- path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
- path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
-
- out = self.scratch.output_conv(path_1)
-
- return torch.squeeze(out, dim=1)
-
-
-
-def fuse_model(m):
- prev_previous_type = nn.Identity()
- prev_previous_name = ''
- previous_type = nn.Identity()
- previous_name = ''
- for name, module in m.named_modules():
- if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
- # print("FUSED ", prev_previous_name, previous_name, name)
- torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
- elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
- # print("FUSED ", prev_previous_name, previous_name)
- torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
- # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
- # print("FUSED ", previous_name, name)
- # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
-
- prev_previous_type = previous_type
- prev_previous_name = previous_name
- previous_type = type(module)
- previous_name = name
\ No newline at end of file
diff --git a/comfy/ldm/modules/midas/midas/transforms.py b/comfy/ldm/modules/midas/midas/transforms.py
deleted file mode 100644
index 350cbc116..000000000
--- a/comfy/ldm/modules/midas/midas/transforms.py
+++ /dev/null
@@ -1,234 +0,0 @@
-import numpy as np
-import cv2
-import math
-
-
-def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
- """Rezise the sample to ensure the given size. Keeps aspect ratio.
-
- Args:
- sample (dict): sample
- size (tuple): image size
-
- Returns:
- tuple: new size
- """
- shape = list(sample["disparity"].shape)
-
- if shape[0] >= size[0] and shape[1] >= size[1]:
- return sample
-
- scale = [0, 0]
- scale[0] = size[0] / shape[0]
- scale[1] = size[1] / shape[1]
-
- scale = max(scale)
-
- shape[0] = math.ceil(scale * shape[0])
- shape[1] = math.ceil(scale * shape[1])
-
- # resize
- sample["image"] = cv2.resize(
- sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
- )
-
- sample["disparity"] = cv2.resize(
- sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
- )
- sample["mask"] = cv2.resize(
- sample["mask"].astype(np.float32),
- tuple(shape[::-1]),
- interpolation=cv2.INTER_NEAREST,
- )
- sample["mask"] = sample["mask"].astype(bool)
-
- return tuple(shape)
-
-
-class Resize(object):
- """Resize sample to given size (width, height).
- """
-
- def __init__(
- self,
- width,
- height,
- resize_target=True,
- keep_aspect_ratio=False,
- ensure_multiple_of=1,
- resize_method="lower_bound",
- image_interpolation_method=cv2.INTER_AREA,
- ):
- """Init.
-
- Args:
- width (int): desired output width
- height (int): desired output height
- resize_target (bool, optional):
- True: Resize the full sample (image, mask, target).
- False: Resize image only.
- Defaults to True.
- keep_aspect_ratio (bool, optional):
- True: Keep the aspect ratio of the input sample.
- Output sample might not have the given width and height, and
- resize behaviour depends on the parameter 'resize_method'.
- Defaults to False.
- ensure_multiple_of (int, optional):
- Output width and height is constrained to be multiple of this parameter.
- Defaults to 1.
- resize_method (str, optional):
- "lower_bound": Output will be at least as large as the given size.
- "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
- "minimal": Scale as least as possible. (Output size might be smaller than given size.)
- Defaults to "lower_bound".
- """
- self.__width = width
- self.__height = height
-
- self.__resize_target = resize_target
- self.__keep_aspect_ratio = keep_aspect_ratio
- self.__multiple_of = ensure_multiple_of
- self.__resize_method = resize_method
- self.__image_interpolation_method = image_interpolation_method
-
- def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
- y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
-
- if max_val is not None and y > max_val:
- y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
-
- if y < min_val:
- y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
-
- return y
-
- def get_size(self, width, height):
- # determine new height and width
- scale_height = self.__height / height
- scale_width = self.__width / width
-
- if self.__keep_aspect_ratio:
- if self.__resize_method == "lower_bound":
- # scale such that output size is lower bound
- if scale_width > scale_height:
- # fit width
- scale_height = scale_width
- else:
- # fit height
- scale_width = scale_height
- elif self.__resize_method == "upper_bound":
- # scale such that output size is upper bound
- if scale_width < scale_height:
- # fit width
- scale_height = scale_width
- else:
- # fit height
- scale_width = scale_height
- elif self.__resize_method == "minimal":
- # scale as least as possbile
- if abs(1 - scale_width) < abs(1 - scale_height):
- # fit width
- scale_height = scale_width
- else:
- # fit height
- scale_width = scale_height
- else:
- raise ValueError(
- f"resize_method {self.__resize_method} not implemented"
- )
-
- if self.__resize_method == "lower_bound":
- new_height = self.constrain_to_multiple_of(
- scale_height * height, min_val=self.__height
- )
- new_width = self.constrain_to_multiple_of(
- scale_width * width, min_val=self.__width
- )
- elif self.__resize_method == "upper_bound":
- new_height = self.constrain_to_multiple_of(
- scale_height * height, max_val=self.__height
- )
- new_width = self.constrain_to_multiple_of(
- scale_width * width, max_val=self.__width
- )
- elif self.__resize_method == "minimal":
- new_height = self.constrain_to_multiple_of(scale_height * height)
- new_width = self.constrain_to_multiple_of(scale_width * width)
- else:
- raise ValueError(f"resize_method {self.__resize_method} not implemented")
-
- return (new_width, new_height)
-
- def __call__(self, sample):
- width, height = self.get_size(
- sample["image"].shape[1], sample["image"].shape[0]
- )
-
- # resize sample
- sample["image"] = cv2.resize(
- sample["image"],
- (width, height),
- interpolation=self.__image_interpolation_method,
- )
-
- if self.__resize_target:
- if "disparity" in sample:
- sample["disparity"] = cv2.resize(
- sample["disparity"],
- (width, height),
- interpolation=cv2.INTER_NEAREST,
- )
-
- if "depth" in sample:
- sample["depth"] = cv2.resize(
- sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
- )
-
- sample["mask"] = cv2.resize(
- sample["mask"].astype(np.float32),
- (width, height),
- interpolation=cv2.INTER_NEAREST,
- )
- sample["mask"] = sample["mask"].astype(bool)
-
- return sample
-
-
-class NormalizeImage(object):
- """Normlize image by given mean and std.
- """
-
- def __init__(self, mean, std):
- self.__mean = mean
- self.__std = std
-
- def __call__(self, sample):
- sample["image"] = (sample["image"] - self.__mean) / self.__std
-
- return sample
-
-
-class PrepareForNet(object):
- """Prepare sample for usage as network input.
- """
-
- def __init__(self):
- pass
-
- def __call__(self, sample):
- image = np.transpose(sample["image"], (2, 0, 1))
- sample["image"] = np.ascontiguousarray(image).astype(np.float32)
-
- if "mask" in sample:
- sample["mask"] = sample["mask"].astype(np.float32)
- sample["mask"] = np.ascontiguousarray(sample["mask"])
-
- if "disparity" in sample:
- disparity = sample["disparity"].astype(np.float32)
- sample["disparity"] = np.ascontiguousarray(disparity)
-
- if "depth" in sample:
- depth = sample["depth"].astype(np.float32)
- sample["depth"] = np.ascontiguousarray(depth)
-
- return sample
diff --git a/comfy/ldm/modules/midas/midas/vit.py b/comfy/ldm/modules/midas/midas/vit.py
deleted file mode 100644
index ea46b1be8..000000000
--- a/comfy/ldm/modules/midas/midas/vit.py
+++ /dev/null
@@ -1,491 +0,0 @@
-import torch
-import torch.nn as nn
-import timm
-import types
-import math
-import torch.nn.functional as F
-
-
-class Slice(nn.Module):
- def __init__(self, start_index=1):
- super(Slice, self).__init__()
- self.start_index = start_index
-
- def forward(self, x):
- return x[:, self.start_index :]
-
-
-class AddReadout(nn.Module):
- def __init__(self, start_index=1):
- super(AddReadout, self).__init__()
- self.start_index = start_index
-
- def forward(self, x):
- if self.start_index == 2:
- readout = (x[:, 0] + x[:, 1]) / 2
- else:
- readout = x[:, 0]
- return x[:, self.start_index :] + readout.unsqueeze(1)
-
-
-class ProjectReadout(nn.Module):
- def __init__(self, in_features, start_index=1):
- super(ProjectReadout, self).__init__()
- self.start_index = start_index
-
- self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
-
- def forward(self, x):
- readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
- features = torch.cat((x[:, self.start_index :], readout), -1)
-
- return self.project(features)
-
-
-class Transpose(nn.Module):
- def __init__(self, dim0, dim1):
- super(Transpose, self).__init__()
- self.dim0 = dim0
- self.dim1 = dim1
-
- def forward(self, x):
- x = x.transpose(self.dim0, self.dim1)
- return x
-
-
-def forward_vit(pretrained, x):
- b, c, h, w = x.shape
-
- glob = pretrained.model.forward_flex(x)
-
- layer_1 = pretrained.activations["1"]
- layer_2 = pretrained.activations["2"]
- layer_3 = pretrained.activations["3"]
- layer_4 = pretrained.activations["4"]
-
- layer_1 = pretrained.act_postprocess1[0:2](layer_1)
- layer_2 = pretrained.act_postprocess2[0:2](layer_2)
- layer_3 = pretrained.act_postprocess3[0:2](layer_3)
- layer_4 = pretrained.act_postprocess4[0:2](layer_4)
-
- unflatten = nn.Sequential(
- nn.Unflatten(
- 2,
- torch.Size(
- [
- h // pretrained.model.patch_size[1],
- w // pretrained.model.patch_size[0],
- ]
- ),
- )
- )
-
- if layer_1.ndim == 3:
- layer_1 = unflatten(layer_1)
- if layer_2.ndim == 3:
- layer_2 = unflatten(layer_2)
- if layer_3.ndim == 3:
- layer_3 = unflatten(layer_3)
- if layer_4.ndim == 3:
- layer_4 = unflatten(layer_4)
-
- layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
- layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
- layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
- layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
-
- return layer_1, layer_2, layer_3, layer_4
-
-
-def _resize_pos_embed(self, posemb, gs_h, gs_w):
- posemb_tok, posemb_grid = (
- posemb[:, : self.start_index],
- posemb[0, self.start_index :],
- )
-
- gs_old = int(math.sqrt(len(posemb_grid)))
-
- posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
- posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
- posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
-
- posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
-
- return posemb
-
-
-def forward_flex(self, x):
- b, c, h, w = x.shape
-
- pos_embed = self._resize_pos_embed(
- self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
- )
-
- B = x.shape[0]
-
- if hasattr(self.patch_embed, "backbone"):
- x = self.patch_embed.backbone(x)
- if isinstance(x, (list, tuple)):
- x = x[-1] # last feature if backbone outputs list/tuple of features
-
- x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
-
- if getattr(self, "dist_token", None) is not None:
- cls_tokens = self.cls_token.expand(
- B, -1, -1
- ) # stole cls_tokens impl from Phil Wang, thanks
- dist_token = self.dist_token.expand(B, -1, -1)
- x = torch.cat((cls_tokens, dist_token, x), dim=1)
- else:
- cls_tokens = self.cls_token.expand(
- B, -1, -1
- ) # stole cls_tokens impl from Phil Wang, thanks
- x = torch.cat((cls_tokens, x), dim=1)
-
- x = x + pos_embed
- x = self.pos_drop(x)
-
- for blk in self.blocks:
- x = blk(x)
-
- x = self.norm(x)
-
- return x
-
-
-activations = {}
-
-
-def get_activation(name):
- def hook(model, input, output):
- activations[name] = output
-
- return hook
-
-
-def get_readout_oper(vit_features, features, use_readout, start_index=1):
- if use_readout == "ignore":
- readout_oper = [Slice(start_index)] * len(features)
- elif use_readout == "add":
- readout_oper = [AddReadout(start_index)] * len(features)
- elif use_readout == "project":
- readout_oper = [
- ProjectReadout(vit_features, start_index) for out_feat in features
- ]
- else:
- assert (
- False
- ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
-
- return readout_oper
-
-
-def _make_vit_b16_backbone(
- model,
- features=[96, 192, 384, 768],
- size=[384, 384],
- hooks=[2, 5, 8, 11],
- vit_features=768,
- use_readout="ignore",
- start_index=1,
-):
- pretrained = nn.Module()
-
- pretrained.model = model
- pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
- pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
- pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
- pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
-
- pretrained.activations = activations
-
- readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
-
- # 32, 48, 136, 384
- pretrained.act_postprocess1 = nn.Sequential(
- readout_oper[0],
- Transpose(1, 2),
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
- nn.Conv2d(
- in_channels=vit_features,
- out_channels=features[0],
- kernel_size=1,
- stride=1,
- padding=0,
- ),
- nn.ConvTranspose2d(
- in_channels=features[0],
- out_channels=features[0],
- kernel_size=4,
- stride=4,
- padding=0,
- bias=True,
- dilation=1,
- groups=1,
- ),
- )
-
- pretrained.act_postprocess2 = nn.Sequential(
- readout_oper[1],
- Transpose(1, 2),
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
- nn.Conv2d(
- in_channels=vit_features,
- out_channels=features[1],
- kernel_size=1,
- stride=1,
- padding=0,
- ),
- nn.ConvTranspose2d(
- in_channels=features[1],
- out_channels=features[1],
- kernel_size=2,
- stride=2,
- padding=0,
- bias=True,
- dilation=1,
- groups=1,
- ),
- )
-
- pretrained.act_postprocess3 = nn.Sequential(
- readout_oper[2],
- Transpose(1, 2),
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
- nn.Conv2d(
- in_channels=vit_features,
- out_channels=features[2],
- kernel_size=1,
- stride=1,
- padding=0,
- ),
- )
-
- pretrained.act_postprocess4 = nn.Sequential(
- readout_oper[3],
- Transpose(1, 2),
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
- nn.Conv2d(
- in_channels=vit_features,
- out_channels=features[3],
- kernel_size=1,
- stride=1,
- padding=0,
- ),
- nn.Conv2d(
- in_channels=features[3],
- out_channels=features[3],
- kernel_size=3,
- stride=2,
- padding=1,
- ),
- )
-
- pretrained.model.start_index = start_index
- pretrained.model.patch_size = [16, 16]
-
- # We inject this function into the VisionTransformer instances so that
- # we can use it with interpolated position embeddings without modifying the library source.
- pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
- pretrained.model._resize_pos_embed = types.MethodType(
- _resize_pos_embed, pretrained.model
- )
-
- return pretrained
-
-
-def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
- model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
-
- hooks = [5, 11, 17, 23] if hooks == None else hooks
- return _make_vit_b16_backbone(
- model,
- features=[256, 512, 1024, 1024],
- hooks=hooks,
- vit_features=1024,
- use_readout=use_readout,
- )
-
-
-def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
- model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
-
- hooks = [2, 5, 8, 11] if hooks == None else hooks
- return _make_vit_b16_backbone(
- model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
- )
-
-
-def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
- model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
-
- hooks = [2, 5, 8, 11] if hooks == None else hooks
- return _make_vit_b16_backbone(
- model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
- )
-
-
-def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
- model = timm.create_model(
- "vit_deit_base_distilled_patch16_384", pretrained=pretrained
- )
-
- hooks = [2, 5, 8, 11] if hooks == None else hooks
- return _make_vit_b16_backbone(
- model,
- features=[96, 192, 384, 768],
- hooks=hooks,
- use_readout=use_readout,
- start_index=2,
- )
-
-
-def _make_vit_b_rn50_backbone(
- model,
- features=[256, 512, 768, 768],
- size=[384, 384],
- hooks=[0, 1, 8, 11],
- vit_features=768,
- use_vit_only=False,
- use_readout="ignore",
- start_index=1,
-):
- pretrained = nn.Module()
-
- pretrained.model = model
-
- if use_vit_only == True:
- pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
- pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
- else:
- pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
- get_activation("1")
- )
- pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
- get_activation("2")
- )
-
- pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
- pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
-
- pretrained.activations = activations
-
- readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
-
- if use_vit_only == True:
- pretrained.act_postprocess1 = nn.Sequential(
- readout_oper[0],
- Transpose(1, 2),
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
- nn.Conv2d(
- in_channels=vit_features,
- out_channels=features[0],
- kernel_size=1,
- stride=1,
- padding=0,
- ),
- nn.ConvTranspose2d(
- in_channels=features[0],
- out_channels=features[0],
- kernel_size=4,
- stride=4,
- padding=0,
- bias=True,
- dilation=1,
- groups=1,
- ),
- )
-
- pretrained.act_postprocess2 = nn.Sequential(
- readout_oper[1],
- Transpose(1, 2),
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
- nn.Conv2d(
- in_channels=vit_features,
- out_channels=features[1],
- kernel_size=1,
- stride=1,
- padding=0,
- ),
- nn.ConvTranspose2d(
- in_channels=features[1],
- out_channels=features[1],
- kernel_size=2,
- stride=2,
- padding=0,
- bias=True,
- dilation=1,
- groups=1,
- ),
- )
- else:
- pretrained.act_postprocess1 = nn.Sequential(
- nn.Identity(), nn.Identity(), nn.Identity()
- )
- pretrained.act_postprocess2 = nn.Sequential(
- nn.Identity(), nn.Identity(), nn.Identity()
- )
-
- pretrained.act_postprocess3 = nn.Sequential(
- readout_oper[2],
- Transpose(1, 2),
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
- nn.Conv2d(
- in_channels=vit_features,
- out_channels=features[2],
- kernel_size=1,
- stride=1,
- padding=0,
- ),
- )
-
- pretrained.act_postprocess4 = nn.Sequential(
- readout_oper[3],
- Transpose(1, 2),
- nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
- nn.Conv2d(
- in_channels=vit_features,
- out_channels=features[3],
- kernel_size=1,
- stride=1,
- padding=0,
- ),
- nn.Conv2d(
- in_channels=features[3],
- out_channels=features[3],
- kernel_size=3,
- stride=2,
- padding=1,
- ),
- )
-
- pretrained.model.start_index = start_index
- pretrained.model.patch_size = [16, 16]
-
- # We inject this function into the VisionTransformer instances so that
- # we can use it with interpolated position embeddings without modifying the library source.
- pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
-
- # We inject this function into the VisionTransformer instances so that
- # we can use it with interpolated position embeddings without modifying the library source.
- pretrained.model._resize_pos_embed = types.MethodType(
- _resize_pos_embed, pretrained.model
- )
-
- return pretrained
-
-
-def _make_pretrained_vitb_rn50_384(
- pretrained, use_readout="ignore", hooks=None, use_vit_only=False
-):
- model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
-
- hooks = [0, 1, 8, 11] if hooks == None else hooks
- return _make_vit_b_rn50_backbone(
- model,
- features=[256, 512, 768, 768],
- size=[384, 384],
- hooks=hooks,
- use_vit_only=use_vit_only,
- use_readout=use_readout,
- )
diff --git a/comfy/ldm/modules/midas/utils.py b/comfy/ldm/modules/midas/utils.py
deleted file mode 100644
index 9a9d3b5b6..000000000
--- a/comfy/ldm/modules/midas/utils.py
+++ /dev/null
@@ -1,189 +0,0 @@
-"""Utils for monoDepth."""
-import sys
-import re
-import numpy as np
-import cv2
-import torch
-
-
-def read_pfm(path):
- """Read pfm file.
-
- Args:
- path (str): path to file
-
- Returns:
- tuple: (data, scale)
- """
- with open(path, "rb") as file:
-
- color = None
- width = None
- height = None
- scale = None
- endian = None
-
- header = file.readline().rstrip()
- if header.decode("ascii") == "PF":
- color = True
- elif header.decode("ascii") == "Pf":
- color = False
- else:
- raise Exception("Not a PFM file: " + path)
-
- dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
- if dim_match:
- width, height = list(map(int, dim_match.groups()))
- else:
- raise Exception("Malformed PFM header.")
-
- scale = float(file.readline().decode("ascii").rstrip())
- if scale < 0:
- # little-endian
- endian = "<"
- scale = -scale
- else:
- # big-endian
- endian = ">"
-
- data = np.fromfile(file, endian + "f")
- shape = (height, width, 3) if color else (height, width)
-
- data = np.reshape(data, shape)
- data = np.flipud(data)
-
- return data, scale
-
-
-def write_pfm(path, image, scale=1):
- """Write pfm file.
-
- Args:
- path (str): pathto file
- image (array): data
- scale (int, optional): Scale. Defaults to 1.
- """
-
- with open(path, "wb") as file:
- color = None
-
- if image.dtype.name != "float32":
- raise Exception("Image dtype must be float32.")
-
- image = np.flipud(image)
-
- if len(image.shape) == 3 and image.shape[2] == 3: # color image
- color = True
- elif (
- len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
- ): # greyscale
- color = False
- else:
- raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
-
- file.write("PF\n" if color else "Pf\n".encode())
- file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
-
- endian = image.dtype.byteorder
-
- if endian == "<" or endian == "=" and sys.byteorder == "little":
- scale = -scale
-
- file.write("%f\n".encode() % scale)
-
- image.tofile(file)
-
-
-def read_image(path):
- """Read image and output RGB image (0-1).
-
- Args:
- path (str): path to file
-
- Returns:
- array: RGB image (0-1)
- """
- img = cv2.imread(path)
-
- if img.ndim == 2:
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
-
- return img
-
-
-def resize_image(img):
- """Resize image and make it fit for network.
-
- Args:
- img (array): image
-
- Returns:
- tensor: data ready for network
- """
- height_orig = img.shape[0]
- width_orig = img.shape[1]
-
- if width_orig > height_orig:
- scale = width_orig / 384
- else:
- scale = height_orig / 384
-
- height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
- width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
-
- img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
-
- img_resized = (
- torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
- )
- img_resized = img_resized.unsqueeze(0)
-
- return img_resized
-
-
-def resize_depth(depth, width, height):
- """Resize depth map and bring to CPU (numpy).
-
- Args:
- depth (tensor): depth
- width (int): image width
- height (int): image height
-
- Returns:
- array: processed depth
- """
- depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
-
- depth_resized = cv2.resize(
- depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
- )
-
- return depth_resized
-
-def write_depth(path, depth, bits=1):
- """Write depth map to pfm and png file.
-
- Args:
- path (str): filepath without extension
- depth (array): depth
- """
- write_pfm(path + ".pfm", depth.astype(np.float32))
-
- depth_min = depth.min()
- depth_max = depth.max()
-
- max_val = (2**(8*bits))-1
-
- if depth_max - depth_min > np.finfo("float").eps:
- out = max_val * (depth - depth_min) / (depth_max - depth_min)
- else:
- out = np.zeros(depth.shape, dtype=depth.type)
-
- if bits == 1:
- cv2.imwrite(path + ".png", out.astype("uint8"))
- elif bits == 2:
- cv2.imwrite(path + ".png", out.astype("uint16"))
-
- return
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 7370c19fd..9adea9a5d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -60,6 +60,37 @@ class SD21UNCLIP(BaseModel):
super().__init__(unet_config, v_prediction)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
+ def encode_adm(self, **kwargs):
+ unclip_conditioning = kwargs.get("unclip_conditioning", None)
+ device = kwargs["device"]
+
+ if unclip_conditioning is not None:
+ adm_inputs = []
+ weights = []
+ noise_aug = []
+ for unclip_cond in unclip_conditioning:
+ adm_cond = unclip_cond["clip_vision_output"].image_embeds
+ weight = unclip_cond["strength"]
+ noise_augment = unclip_cond["noise_augmentation"]
+ noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
+ c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
+ adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
+ weights.append(weight)
+ noise_aug.append(noise_augment)
+ adm_inputs.append(adm_out)
+
+ if len(noise_aug) > 1:
+ adm_out = torch.stack(adm_inputs).sum(0)
+ #TODO: add a way to control this
+ noise_augment = 0.05
+ noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
+ c_adm, noise_level_emb = self.noise_augmentor(adm_out[:, :self.noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
+ adm_out = torch.cat((c_adm, noise_level_emb), 1)
+ else:
+ adm_out = torch.zeros((1, self.adm_channels))
+
+ return adm_out
+
class SDInpaint(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 1a8a1be17..d64dce187 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -151,7 +151,7 @@ if args.lowvram:
lowvram_available = True
elif args.novram:
set_vram_to = VRAMState.NO_VRAM
-elif args.highvram:
+elif args.highvram or args.gpu_only:
vram_state = VRAMState.HIGH_VRAM
FORCE_FP32 = False
@@ -307,6 +307,12 @@ def unload_if_low_vram(model):
return model.cpu()
return model
+def text_encoder_device():
+ if args.gpu_only:
+ return get_torch_device()
+ else:
+ return torch.device("cpu")
+
def get_autocast_device(dev):
if hasattr(dev, 'type'):
return dev.type
diff --git a/comfy/ops.py b/comfy/ops.py
new file mode 100644
index 000000000..2e72030bd
--- /dev/null
+++ b/comfy/ops.py
@@ -0,0 +1,32 @@
+import torch
+from contextlib import contextmanager
+
+class Linear(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
+ device=None, dtype=None) -> None:
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
+ if bias:
+ self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return torch.nn.functional.linear(input, self.weight, self.bias)
+
+class Conv2d(torch.nn.Conv2d):
+ def reset_parameters(self):
+ return None
+
+
+@contextmanager
+def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
+ old_torch_nn_linear = torch.nn.Linear
+ torch.nn.Linear = Linear
+ try:
+ yield
+ finally:
+ torch.nn.Linear = old_torch_nn_linear
diff --git a/comfy/samplers.py b/comfy/samplers.py
index a33d150d0..dffd7fe7c 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -273,7 +273,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
if "sampler_cfg_function" in model_options:
- return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
+ args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
+ return model_options["sampler_cfg_function"](args)
else:
return uncond + (cond - uncond) * cond_scale
@@ -460,42 +461,18 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
uncond[temp[1]] = [o[0], n]
-def encode_adm(conds, batch_size, device, noise_augmentor=None):
+def encode_adm(model, conds, batch_size, device):
for t in range(len(conds)):
x = conds[t]
adm_out = None
- if noise_augmentor is not None:
- if 'adm' in x[1]:
- adm_inputs = []
- weights = []
- noise_aug = []
- adm_in = x[1]["adm"]
- for adm_c in adm_in:
- adm_cond = adm_c[0].image_embeds
- weight = adm_c[1]
- noise_augment = adm_c[2]
- noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
- c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
- adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
- weights.append(weight)
- noise_aug.append(noise_augment)
- adm_inputs.append(adm_out)
-
- if len(noise_aug) > 1:
- adm_out = torch.stack(adm_inputs).sum(0)
- #TODO: add a way to control this
- noise_augment = 0.05
- noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
- c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
- adm_out = torch.cat((c_adm, noise_level_emb), 1)
- else:
- adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
+ if 'adm' in x[1]:
+ adm_out = x[1]["adm"]
else:
- if 'adm' in x[1]:
- adm_out = x[1]["adm"].to(device)
+ params = x[1].copy()
+ adm_out = model.encode_adm(device=device, **params)
if adm_out is not None:
x[1] = x[1].copy()
- x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
+ x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device)
return conds
@@ -603,11 +580,8 @@ class KSampler:
precision_scope = contextlib.nullcontext
if self.model.is_adm():
- noise_augmentor = None
- if hasattr(self.model, 'noise_augmentor'): #unclip
- noise_augmentor = self.model.noise_augmentor
- positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor)
- negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor)
+ positive = encode_adm(self.model, positive, noise.shape[0], self.device)
+ negative = encode_adm(self.model, negative, noise.shape[0], self.device)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
diff --git a/comfy/sd.py b/comfy/sd.py
index 3747f53b8..e6cda5131 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1,6 +1,7 @@
import torch
import contextlib
import copy
+import inspect
from . import sd1_clip
from . import sd2_clip
@@ -85,7 +86,7 @@ LORA_UNET_MAP_RESNET = {
}
def load_lora(path, to_load):
- lora = utils.load_torch_file(path)
+ lora = utils.load_torch_file(path, safe_load=True)
patch_dict = {}
loaded_keys = set()
for x in to_load:
@@ -313,8 +314,10 @@ class ModelPatcher:
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
def set_model_sampler_cfg_function(self, sampler_cfg_function):
- self.model_options["sampler_cfg_function"] = sampler_cfg_function
-
+ if len(inspect.signature(sampler_cfg_function).parameters) == 3:
+ self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
+ else:
+ self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
@@ -328,6 +331,9 @@ class ModelPatcher:
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
+ def set_model_attn2_output_patch(self, patch):
+ self.set_model_patch(patch, "attn2_output_patch")
+
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
@@ -464,7 +470,11 @@ class CLIP:
clip = sd1_clip.SD1ClipModel
tokenizer = sd1_clip.SD1Tokenizer
+ self.device = model_management.text_encoder_device()
+ params["device"] = self.device
self.cond_stage_model = clip(**(params))
+ self.cond_stage_model = self.cond_stage_model.to(self.device)
+
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = ModelPatcher(self.cond_stage_model)
self.layer_idx = None
@@ -544,6 +554,19 @@ class VAE:
/ 3.0) / 2.0, min=0.0, max=1.0)
return output
+ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+ steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
+ steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
+ steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
+ pbar = utils.ProgressBar(steps)
+
+ encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor
+ samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+ samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+ samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+ samples /= 3.0
+ return samples
+
def decode(self, samples_in):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
@@ -574,28 +597,29 @@ class VAE:
def encode(self, pixel_samples):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
- pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
- samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor
+ pixel_samples = pixel_samples.movedim(-1,1)
+ try:
+ free_memory = model_management.get_free_memory(self.device)
+ batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
+ batch_number = max(1, batch_number)
+ samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
+ for x in range(0, pixel_samples.shape[0], batch_number):
+ pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
+ samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
+
+ except model_management.OOM_EXCEPTION as e:
+ print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+ samples = self.encode_tiled_(pixel_samples)
+
self.first_stage_model = self.first_stage_model.cpu()
- samples = samples.cpu()
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
- pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
-
- steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
- steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
- steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
- pbar = utils.ProgressBar(steps)
-
- samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
- samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
- samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
- samples /= 3.0
+ pixel_samples = pixel_samples.movedim(-1,1)
+ samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
self.first_stage_model = self.first_stage_model.cpu()
- samples = samples.cpu()
return samples
def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -708,7 +732,7 @@ class ControlNet:
return out
def load_controlnet(ckpt_path, model=None):
- controlnet_data = utils.load_torch_file(ckpt_path)
+ controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
pth = False
sd2 = False
@@ -910,7 +934,7 @@ class StyleModel:
def load_style_model(ckpt_path):
- model_data = utils.load_torch_file(ckpt_path)
+ model_data = utils.load_torch_file(ckpt_path, safe_load=True)
keys = model_data.keys()
if "style_embedding" in keys:
model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
@@ -921,7 +945,7 @@ def load_style_model(ckpt_path):
def load_clip(ckpt_path, embedding_directory=None):
- clip_data = utils.load_torch_file(ckpt_path)
+ clip_data = utils.load_torch_file(ckpt_path, safe_load=True)
config = {}
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
@@ -932,7 +956,7 @@ def load_clip(ckpt_path, embedding_directory=None):
return clip
def load_gligen(ckpt_path):
- data = utils.load_torch_file(ckpt_path)
+ data = utils.load_torch_file(ckpt_path, safe_load=True)
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
@@ -1097,7 +1121,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
- model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
unclip_model = False
inpaint_model = False
@@ -1107,11 +1130,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
unclip_model = True
- model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None
- model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
inpaint_model = True
else:
sd_config["conditioning_key"] = "crossattn"
@@ -1143,7 +1164,4 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
- if fp16:
- model = model.half()
-
return (ModelPatcher(model), clip, vae, clipvision)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 91fb4ff27..fa6d22dcb 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -1,6 +1,7 @@
import os
-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
+from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
+import comfy.ops
import torch
import traceback
import zipfile
@@ -19,7 +20,7 @@ class ClipTokenWeightEncoder:
output += [z]
if (len(output) == 0):
return self.encode(self.empty_tokens)
- return torch.cat(output, dim=-2)
+ return torch.cat(output, dim=-2).cpu()
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
@@ -38,7 +39,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config)
- self.transformer = CLIPTextModel(config)
+ with comfy.ops.use_comfy_ops():
+ with modeling_utils.no_init_weights():
+ self.transformer = CLIPTextModel(config)
self.device = device
self.max_length = max_length
diff --git a/comfy/utils.py b/comfy/utils.py
index 585ebda51..401eb8038 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1,6 +1,7 @@
import torch
import math
import struct
+import comfy.checkpoint_pickle
def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
@@ -14,7 +15,7 @@ def load_torch_file(ckpt, safe_load=False):
if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
else:
- pl_sd = torch.load(ckpt, map_location="cpu")
+ pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py
index c19b5e4c7..d16c49aeb 100644
--- a/comfy_extras/nodes_hypernetwork.py
+++ b/comfy_extras/nodes_hypernetwork.py
@@ -68,7 +68,7 @@ def load_hypernetwork_patch(path, strength):
def __init__(self, hypernet, strength):
self.hypernet = hypernet
self.strength = strength
- def __call__(self, current_index, q, k, v):
+ def __call__(self, q, k, v, extra_options):
dim = k.shape[-1]
if dim in self.hypernet:
hn = self.hypernet[dim]
diff --git a/execution.py b/execution.py
index 218a84c36..f93de8465 100644
--- a/execution.py
+++ b/execution.py
@@ -310,7 +310,6 @@ class PromptExecutor:
else:
self.server.client_id = None
- execution_start_time = time.perf_counter()
if self.server.client_id is not None:
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
@@ -358,12 +357,7 @@ class PromptExecutor:
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None
- if self.server.client_id is not None:
- self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
- print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
- gc.collect()
- comfy.model_management.soft_empty_cache()
def validate_inputs(prompt, item, validated):
@@ -728,9 +722,14 @@ class PromptQueue:
return True
return False
- def get_history(self):
+ def get_history(self, prompt_id=None):
with self.mutex:
- return copy.deepcopy(self.history)
+ if prompt_id is None:
+ return copy.deepcopy(self.history)
+ elif prompt_id in self.history:
+ return {prompt_id: copy.deepcopy(self.history[prompt_id])}
+ else:
+ return {}
def wipe_history(self):
with self.mutex:
diff --git a/main.py b/main.py
index 8293c06fc..22425d2aa 100644
--- a/main.py
+++ b/main.py
@@ -3,6 +3,8 @@ import itertools
import os
import shutil
import threading
+import gc
+import time
from comfy.cli_args import args
import comfy.utils
@@ -28,15 +30,22 @@ import folder_paths
import server
from server import BinaryEventTypes
from nodes import init_custom_nodes
-
+import comfy.model_management
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
while True:
item, item_id = q.get()
- e.execute(item[2], item[1], item[3], item[4])
+ execution_start_time = time.perf_counter()
+ prompt_id = item[1]
+ e.execute(item[2], prompt_id, item[3], item[4])
q.task_done(item_id, e.outputs_ui)
+ if server.client_id is not None:
+ server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
+ print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
+ gc.collect()
+ comfy.model_management.soft_empty_cache()
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
diff --git a/nodes.py b/nodes.py
index 45627a91e..a5949a408 100644
--- a/nodes.py
+++ b/nodes.py
@@ -626,11 +626,11 @@ class unCLIPConditioning:
c = []
for t in conditioning:
o = t[1].copy()
- x = (clip_vision_output, strength, noise_augmentation)
- if "adm" in o:
- o["adm"] = o["adm"][:] + [x]
+ x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
+ if "unclip_conditioning" in o:
+ o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
else:
- o["adm"] = [x]
+ o["unclip_conditioning"] = [x]
n = [t[0], o]
c.append(n)
return (c, )
@@ -759,7 +759,7 @@ class RepeatLatentBatch:
return (s,)
class LatentUpscale:
- upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
crop_methods = ["disabled", "center"]
@classmethod
@@ -779,7 +779,7 @@ class LatentUpscale:
return (s,)
class LatentUpscaleBy:
- upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
@classmethod
def INPUT_TYPES(s):
@@ -1175,7 +1175,7 @@ class LoadImageMask:
return True
class ImageScale:
- upscale_methods = ["nearest-exact", "bilinear", "area"]
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
crop_methods = ["disabled", "center"]
@classmethod
@@ -1195,6 +1195,26 @@ class ImageScale:
s = s.movedim(1,-1)
return (s,)
+class ImageScaleBy:
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
+ "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}}
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "upscale"
+
+ CATEGORY = "image/upscaling"
+
+ def upscale(self, image, upscale_method, scale_by):
+ samples = image.movedim(-1,1)
+ width = round(samples.shape[3] * scale_by)
+ height = round(samples.shape[2] * scale_by)
+ s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+ s = s.movedim(1,-1)
+ return (s,)
+
class ImageInvert:
@classmethod
@@ -1293,6 +1313,7 @@ NODE_CLASS_MAPPINGS = {
"LoadImage": LoadImage,
"LoadImageMask": LoadImageMask,
"ImageScale": ImageScale,
+ "ImageScaleBy": ImageScaleBy,
"ImageInvert": ImageInvert,
"ImagePadForOutpaint": ImagePadForOutpaint,
"ConditioningAverage ": ConditioningAverage ,
@@ -1374,6 +1395,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadImage": "Load Image",
"LoadImageMask": "Load Image (as Mask)",
"ImageScale": "Upscale Image",
+ "ImageScaleBy": "Upscale Image By",
"ImageUpscaleWithModel": "Upscale Image (using Model)",
"ImageInvert": "Invert Image",
"ImagePadForOutpaint": "Pad Image for Outpainting",
diff --git a/requirements.txt b/requirements.txt
index 0527b31df..d632edf79 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,10 +2,11 @@ torch
torchdiffeq
torchsde
einops
-open-clip-torch
transformers>=4.25.1
safetensors>=0.3.0
-pytorch_lightning
aiohttp
accelerate
pyyaml
+Pillow
+scipy
+tqdm
diff --git a/script_examples/websockets_api_example.py b/script_examples/websockets_api_example.py
new file mode 100644
index 000000000..57a6cbd9b
--- /dev/null
+++ b/script_examples/websockets_api_example.py
@@ -0,0 +1,164 @@
+#This is an example that uses the websockets api to know when a prompt execution is done
+#Once the prompt execution is done it downloads the images using the /history endpoint
+
+import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
+import uuid
+import json
+import urllib.request
+import urllib.parse
+
+server_address = "127.0.0.1:8188"
+client_id = str(uuid.uuid4())
+
+def queue_prompt(prompt):
+ p = {"prompt": prompt, "client_id": client_id}
+ data = json.dumps(p).encode('utf-8')
+ req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
+ return json.loads(urllib.request.urlopen(req).read())
+
+def get_image(filename, subfolder, folder_type):
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
+ url_values = urllib.parse.urlencode(data)
+ with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
+ return response.read()
+
+def get_history(prompt_id):
+ with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
+ return json.loads(response.read())
+
+def get_images(ws, prompt):
+ prompt_id = queue_prompt(prompt)['prompt_id']
+ output_images = {}
+ while True:
+ out = ws.recv()
+ if isinstance(out, str):
+ message = json.loads(out)
+ if message['type'] == 'executing':
+ data = message['data']
+ if data['node'] is None and data['prompt_id'] == prompt_id:
+ break #Execution is done
+ else:
+ continue #previews are binary data
+
+ history = get_history(prompt_id)[prompt_id]
+ for o in history['outputs']:
+ for node_id in history['outputs']:
+ node_output = history['outputs'][node_id]
+ if 'images' in node_output:
+ images_output = []
+ for image in node_output['images']:
+ image_data = get_image(image['filename'], image['subfolder'], image['type'])
+ images_output.append(image_data)
+ output_images[node_id] = images_output
+
+ return output_images
+
+prompt_text = """
+{
+ "3": {
+ "class_type": "KSampler",
+ "inputs": {
+ "cfg": 8,
+ "denoise": 1,
+ "latent_image": [
+ "5",
+ 0
+ ],
+ "model": [
+ "4",
+ 0
+ ],
+ "negative": [
+ "7",
+ 0
+ ],
+ "positive": [
+ "6",
+ 0
+ ],
+ "sampler_name": "euler",
+ "scheduler": "normal",
+ "seed": 8566257,
+ "steps": 20
+ }
+ },
+ "4": {
+ "class_type": "CheckpointLoaderSimple",
+ "inputs": {
+ "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+ }
+ },
+ "5": {
+ "class_type": "EmptyLatentImage",
+ "inputs": {
+ "batch_size": 1,
+ "height": 512,
+ "width": 512
+ }
+ },
+ "6": {
+ "class_type": "CLIPTextEncode",
+ "inputs": {
+ "clip": [
+ "4",
+ 1
+ ],
+ "text": "masterpiece best quality girl"
+ }
+ },
+ "7": {
+ "class_type": "CLIPTextEncode",
+ "inputs": {
+ "clip": [
+ "4",
+ 1
+ ],
+ "text": "bad hands"
+ }
+ },
+ "8": {
+ "class_type": "VAEDecode",
+ "inputs": {
+ "samples": [
+ "3",
+ 0
+ ],
+ "vae": [
+ "4",
+ 2
+ ]
+ }
+ },
+ "9": {
+ "class_type": "SaveImage",
+ "inputs": {
+ "filename_prefix": "ComfyUI",
+ "images": [
+ "8",
+ 0
+ ]
+ }
+ }
+}
+"""
+
+prompt = json.loads(prompt_text)
+#set the text prompt for our positive CLIPTextEncode
+prompt["6"]["inputs"]["text"] = "masterpiece best quality man"
+
+#set the seed for our KSampler node
+prompt["3"]["inputs"]["seed"] = 5
+
+ws = websocket.WebSocket()
+ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
+images = get_images(ws, prompt)
+
+#Commented out code to display the output images:
+
+# for node_id in images:
+# for image_data in images[node_id]:
+# from PIL import Image
+# import io
+# image = Image.open(io.BytesIO(image_data))
+# image.show()
+
diff --git a/server.py b/server.py
index 174d38af1..f385cefb8 100644
--- a/server.py
+++ b/server.py
@@ -30,6 +30,11 @@ import comfy.model_management
class BinaryEventTypes:
PREVIEW_IMAGE = 1
+async def send_socket_catch_exception(function, message):
+ try:
+ await function(message)
+ except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
+ print("send error:", err)
@web.middleware
async def cache_control(request: web.Request, handler):
@@ -372,6 +377,11 @@ class PromptServer():
async def get_history(request):
return web.json_response(self.prompt_queue.get_history())
+ @routes.get("/history/{prompt_id}")
+ async def get_history(request):
+ prompt_id = request.match_info.get("prompt_id", None)
+ return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
+
@routes.get("/queue")
async def get_queue(request):
queue_info = {}
@@ -482,18 +492,18 @@ class PromptServer():
if sid is None:
for ws in self.sockets.values():
- await ws.send_bytes(message)
+ await send_socket_catch_exception(ws.send_bytes, message)
elif sid in self.sockets:
- await self.sockets[sid].send_bytes(message)
+ await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
async def send_json(self, event, data, sid=None):
message = {"type": event, "data": data}
if sid is None:
for ws in self.sockets.values():
- await ws.send_json(message)
+ await send_socket_catch_exception(ws.send_json, message)
elif sid in self.sockets:
- await self.sockets[sid].send_json(message)
+ await send_socket_catch_exception(self.sockets[sid].send_json, message)
def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe(
diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js
index 84c2a3d10..9836143d3 100644
--- a/web/extensions/core/colorPalette.js
+++ b/web/extensions/core/colorPalette.js
@@ -1,6 +1,5 @@
-import { app } from "/scripts/app.js";
-import { $el } from "/scripts/ui.js";
-import { api } from "/scripts/api.js";
+import {app} from "/scripts/app.js";
+import {$el} from "/scripts/ui.js";
// Manage color palettes
@@ -24,6 +23,8 @@ const colorPalettes = {
"TAESD": "#DCC274", // cheesecake
},
"litegraph_base": {
+ "BACKGROUND_IMAGE": "",
+ "CLEAR_BACKGROUND_COLOR": "#222",
"NODE_TITLE_COLOR": "#999",
"NODE_SELECTED_TITLE_COLOR": "#FFF",
"NODE_TEXT_SIZE": 14,
@@ -55,7 +56,9 @@ const colorPalettes = {
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
- "border-color": "#4e4e4e"
+ "border-color": "#4e4e4e",
+ "tr-even-bg-color": "#222",
+ "tr-odd-bg-color": "#353535",
}
},
},
@@ -77,6 +80,8 @@ const colorPalettes = {
"VAE": "#FF7043", // deep orange
},
"litegraph_base": {
+ "BACKGROUND_IMAGE": "",
+ "CLEAR_BACKGROUND_COLOR": "lightgray",
"NODE_TITLE_COLOR": "#222",
"NODE_SELECTED_TITLE_COLOR": "#000",
"NODE_TEXT_SIZE": 14,
@@ -108,7 +113,9 @@ const colorPalettes = {
"descrip-text": "#444",
"drag-text": "#555",
"error-text": "#F44336",
- "border-color": "#888"
+ "border-color": "#888",
+ "tr-even-bg-color": "#f9f9f9",
+ "tr-odd-bg-color": "#fff",
}
},
},
@@ -162,7 +169,9 @@ const colorPalettes = {
"descrip-text": "#586e75", // Base01
"drag-text": "#839496", // Base0
"error-text": "#dc322f", // Solarized Red
- "border-color": "#657b83" // Base00
+ "border-color": "#657b83", // Base00
+ "tr-even-bg-color": "#002b36",
+ "tr-odd-bg-color": "#073642",
}
},
}
@@ -191,7 +200,7 @@ app.registerExtension({
const nodeData = defs[nodeId];
var inputs = nodeData["input"]["required"];
- if (nodeData["input"]["optional"] != undefined){
+ if (nodeData["input"]["optional"] !== undefined) {
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"])
}
@@ -211,7 +220,7 @@ app.registerExtension({
}
return types;
- };
+ }
function completeColorPalette(colorPalette) {
var types = getSlotTypes();
@@ -225,19 +234,16 @@ app.registerExtension({
colorPalette.colors.node_slot = sortObjectKeys(colorPalette.colors.node_slot);
return colorPalette;
- };
+ }
const getColorPaletteTemplate = async () => {
let colorPalette = {
"id": "my_color_palette_unique_id",
"name": "My Color Palette",
"colors": {
- "node_slot": {
- },
- "litegraph_base": {
- },
- "comfy_base": {
- }
+ "node_slot": {},
+ "litegraph_base": {},
+ "comfy_base": {}
}
};
@@ -266,32 +272,32 @@ app.registerExtension({
};
const addCustomColorPalette = async (colorPalette) => {
- if (typeof(colorPalette) !== "object") {
- app.ui.dialog.show("Invalid color palette");
+ if (typeof (colorPalette) !== "object") {
+ alert("Invalid color palette.");
return;
}
if (!colorPalette.id) {
- app.ui.dialog.show("Color palette missing id");
+ alert("Color palette missing id.");
return;
}
if (!colorPalette.name) {
- app.ui.dialog.show("Color palette missing name");
+ alert("Color palette missing name.");
return;
}
if (!colorPalette.colors) {
- app.ui.dialog.show("Color palette missing colors");
+ alert("Color palette missing colors.");
return;
}
- if (colorPalette.colors.node_slot && typeof(colorPalette.colors.node_slot) !== "object") {
- app.ui.dialog.show("Invalid color palette colors.node_slot");
+ if (colorPalette.colors.node_slot && typeof (colorPalette.colors.node_slot) !== "object") {
+ alert("Invalid color palette colors.node_slot.");
return;
}
- let customColorPalettes = getCustomColorPalettes();
+ const customColorPalettes = getCustomColorPalettes();
customColorPalettes[colorPalette.id] = colorPalette;
setCustomColorPalettes(customColorPalettes);
@@ -301,14 +307,18 @@ app.registerExtension({
}
}
- els.select.append($el("option", { textContent: colorPalette.name + " (custom)", value: "custom_" + colorPalette.id, selected: true }));
+ els.select.append($el("option", {
+ textContent: colorPalette.name + " (custom)",
+ value: "custom_" + colorPalette.id,
+ selected: true
+ }));
setColorPalette("custom_" + colorPalette.id);
await loadColorPalette(colorPalette);
};
const deleteCustomColorPalette = async (colorPaletteId) => {
- let customColorPalettes = getCustomColorPalettes();
+ const customColorPalettes = getCustomColorPalettes();
delete customColorPalettes[colorPaletteId];
setCustomColorPalettes(customColorPalettes);
@@ -350,7 +360,7 @@ app.registerExtension({
if (colorPalette.colors.comfy_base) {
const rootStyle = document.documentElement.style;
for (const key in colorPalette.colors.comfy_base) {
- rootStyle.setProperty('--' + key, colorPalette.colors.comfy_base[key]);
+ rootStyle.setProperty('--' + key, colorPalette.colors.comfy_base[key]);
}
}
app.canvas.draw(true, true);
@@ -380,11 +390,10 @@ app.registerExtension({
const fileInput = $el("input", {
type: "file",
accept: ".json",
- style: { display: "none" },
+ style: {display: "none"},
parent: document.body,
onchange: () => {
- let file = fileInput.files[0];
-
+ const file = fileInput.files[0];
if (file.type === "application/json" || file.name.endsWith(".json")) {
const reader = new FileReader();
reader.onload = async () => {
@@ -399,96 +408,116 @@ app.registerExtension({
id,
name: "Color Palette",
type: (name, setter, value) => {
- let options = [];
+ const options = [
+ ...Object.values(colorPalettes).map(c=> $el("option", {
+ textContent: c.name,
+ value: c.id,
+ selected: c.id === value
+ })),
+ ...Object.values(getCustomColorPalettes()).map(c=>$el("option", {
+ textContent: `${c.name} (custom)`,
+ value: `custom_${c.id}`,
+ selected: `custom_${c.id}` === value
+ })) ,
+ ];
- for (const c in colorPalettes) {
- const colorPalette = colorPalettes[c];
- options.push($el("option", { textContent: colorPalette.name, value: colorPalette.id, selected: colorPalette.id === value }));
- }
+ els.select = $el("select", {
+ style: {
+ marginBottom: "0.15rem",
+ width: "100%",
+ },
+ onchange: (e) => {
+ setter(e.target.value);
+ }
+ }, options)
- let customColorPalettes = getCustomColorPalettes();
- for (const c in customColorPalettes) {
- const colorPalette = customColorPalettes[c];
- options.push($el("option", { textContent: colorPalette.name + " (custom)", value: "custom_" + colorPalette.id, selected: "custom_" + colorPalette.id === value }));
- }
-
- return $el("div", [
- $el("label", { textContent: name || id }, [
- els.select = $el("select", {
- onchange: (e) => {
- setter(e.target.value);
- }
- }, options)
+ return $el("tr", [
+ $el("td", [
+ $el("label", {
+ for: id.replaceAll(".", "-"),
+ textContent: "Color palette:",
+ }),
]),
- $el("input", {
- type: "button",
- value: "Export",
- onclick: async () => {
- const colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
- const colorPalette = await completeColorPalette(getColorPalette(colorPaletteId));
- const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
- const blob = new Blob([json], { type: "application/json" });
- const url = URL.createObjectURL(blob);
- const a = $el("a", {
- href: url,
- download: colorPaletteId + ".json",
- style: { display: "none" },
- parent: document.body,
- });
- a.click();
- setTimeout(function () {
- a.remove();
- window.URL.revokeObjectURL(url);
- }, 0);
- },
- }),
- $el("input", {
- type: "button",
- value: "Import",
- onclick: () => {
- fileInput.click();
- }
- }),
- $el("input", {
- type: "button",
- value: "Template",
- onclick: async () => {
- const colorPalette = await getColorPaletteTemplate();
- const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
- const blob = new Blob([json], { type: "application/json" });
- const url = URL.createObjectURL(blob);
- const a = $el("a", {
- href: url,
- download: "color_palette.json",
- style: { display: "none" },
- parent: document.body,
- });
- a.click();
- setTimeout(function () {
- a.remove();
- window.URL.revokeObjectURL(url);
- }, 0);
- }
- }),
- $el("input", {
- type: "button",
- value: "Delete",
- onclick: async () => {
- let colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
+ $el("td", [
+ els.select,
+ $el("div", {
+ style: {
+ display: "grid",
+ gap: "4px",
+ gridAutoFlow: "column",
+ },
+ }, [
+ $el("input", {
+ type: "button",
+ value: "Export",
+ onclick: async () => {
+ const colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
+ const colorPalette = await completeColorPalette(getColorPalette(colorPaletteId));
+ const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
+ const blob = new Blob([json], {type: "application/json"});
+ const url = URL.createObjectURL(blob);
+ const a = $el("a", {
+ href: url,
+ download: colorPaletteId + ".json",
+ style: {display: "none"},
+ parent: document.body,
+ });
+ a.click();
+ setTimeout(function () {
+ a.remove();
+ window.URL.revokeObjectURL(url);
+ }, 0);
+ },
+ }),
+ $el("input", {
+ type: "button",
+ value: "Import",
+ onclick: () => {
+ fileInput.click();
+ }
+ }),
+ $el("input", {
+ type: "button",
+ value: "Template",
+ onclick: async () => {
+ const colorPalette = await getColorPaletteTemplate();
+ const json = JSON.stringify(colorPalette, null, 2); // convert the data to a JSON string
+ const blob = new Blob([json], {type: "application/json"});
+ const url = URL.createObjectURL(blob);
+ const a = $el("a", {
+ href: url,
+ download: "color_palette.json",
+ style: {display: "none"},
+ parent: document.body,
+ });
+ a.click();
+ setTimeout(function () {
+ a.remove();
+ window.URL.revokeObjectURL(url);
+ }, 0);
+ }
+ }),
+ $el("input", {
+ type: "button",
+ value: "Delete",
+ onclick: async () => {
+ let colorPaletteId = app.ui.settings.getSettingValue(id, defaultColorPaletteId);
- if (colorPalettes[colorPaletteId]) {
- app.ui.dialog.show("You cannot delete built-in color palette");
- return;
- }
+ if (colorPalettes[colorPaletteId]) {
+ alert("You cannot delete a built-in color palette.");
+ return;
+ }
- if (colorPaletteId.startsWith("custom_")) {
- colorPaletteId = colorPaletteId.substr(7);
- }
+ if (colorPaletteId.startsWith("custom_")) {
+ colorPaletteId = colorPaletteId.substr(7);
+ }
- await deleteCustomColorPalette(colorPaletteId);
- }
- }),
- ]);
+ await deleteCustomColorPalette(colorPaletteId);
+ }
+ }),
+ ]),
+ ]),
+ ])
},
defaultValue: defaultColorPaletteId,
async onChange(value) {
@@ -496,15 +525,25 @@ app.registerExtension({
return;
}
- if (colorPalettes[value]) {
- await loadColorPalette(colorPalettes[value]);
+ let palette = colorPalettes[value];
+ if (palette) {
+ await loadColorPalette(palette);
} else if (value.startsWith("custom_")) {
value = value.substr(7);
let customColorPalettes = getCustomColorPalettes();
if (customColorPalettes[value]) {
+ palette = customColorPalettes[value];
await loadColorPalette(customColorPalettes[value]);
}
}
+
+ let {BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR} = palette.colors.litegraph_base;
+ if (BACKGROUND_IMAGE === undefined || CLEAR_BACKGROUND_COLOR === undefined) {
+ const base = colorPalettes["dark"].colors.litegraph_base;
+ BACKGROUND_IMAGE = base.BACKGROUND_IMAGE;
+ CLEAR_BACKGROUND_COLOR = base.CLEAR_BACKGROUND_COLOR;
+ }
+ app.canvas.updateBackground(BACKGROUND_IMAGE, CLEAR_BACKGROUND_COLOR);
},
});
},
diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js
index 51e66f924..662d87e74 100644
--- a/web/extensions/core/contextMenuFilter.js
+++ b/web/extensions/core/contextMenuFilter.js
@@ -1,132 +1,138 @@
-import { app } from "/scripts/app.js";
+import {app} from "/scripts/app.js";
// Adds filtering to combo context menus
-const id = "Comfy.ContextMenuFilter";
-app.registerExtension({
- name: id,
+const ext = {
+ name: "Comfy.ContextMenuFilter",
init() {
const ctxMenu = LiteGraph.ContextMenu;
+
LiteGraph.ContextMenu = function (values, options) {
const ctx = ctxMenu.call(this, values, options);
// If we are a dark menu (only used for combo boxes) then add a filter input
if (options?.className === "dark" && values?.length > 10) {
const filter = document.createElement("input");
- Object.assign(filter.style, {
- width: "calc(100% - 10px)",
- border: "0",
- boxSizing: "border-box",
- background: "#333",
- border: "1px solid #999",
- margin: "0 0 5px 5px",
- color: "#fff",
- });
+ filter.classList.add("comfy-context-menu-filter");
filter.placeholder = "Filter list";
this.root.prepend(filter);
- let selectedIndex = 0;
- let items = this.root.querySelectorAll(".litemenu-entry");
- let itemCount = items.length;
- let selectedItem;
+ const items = Array.from(this.root.querySelectorAll(".litemenu-entry"));
+ let displayedItems = [...items];
+ let itemCount = displayedItems.length;
- // Apply highlighting to the selected item
- function updateSelected() {
- if (selectedItem) {
- selectedItem.style.setProperty("background-color", "");
- selectedItem.style.setProperty("color", "");
- }
- selectedItem = items[selectedIndex];
- if (selectedItem) {
- selectedItem.style.setProperty("background-color", "#ccc", "important");
- selectedItem.style.setProperty("color", "#000", "important");
- }
- }
+ // We must request an animation frame for the current node of the active canvas to update.
+ requestAnimationFrame(() => {
+ const currentNode = LGraphCanvas.active_canvas.current_node;
+ const clickedComboValue = currentNode.widgets
+ .filter(w => w.type === "combo" && w.options.values.length === values.length)
+ .find(w => w.options.values.every((v, i) => v === values[i]))
+ .value;
- const positionList = () => {
- const rect = this.root.getBoundingClientRect();
-
- // If the top is off screen then shift the element with scaling applied
- if (rect.top < 0) {
- const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
- const shift = (this.root.clientHeight * scale) / 2;
- this.root.style.top = -shift + "px";
- }
- }
-
- updateSelected();
-
- // Arrow up/down to select items
- filter.addEventListener("keydown", (e) => {
- if (e.key === "ArrowUp") {
- if (selectedIndex === 0) {
- selectedIndex = itemCount - 1;
- } else {
- selectedIndex--;
- }
- updateSelected();
- e.preventDefault();
- } else if (e.key === "ArrowDown") {
- if (selectedIndex === itemCount - 1) {
- selectedIndex = 0;
- } else {
- selectedIndex++;
- }
- updateSelected();
- e.preventDefault();
- } else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
- selectedItem.click();
- } else if(e.key === "Escape") {
- this.close();
- }
- });
-
- filter.addEventListener("input", () => {
- // Hide all items that dont match our filter
- const term = filter.value.toLocaleLowerCase();
- items = this.root.querySelectorAll(".litemenu-entry");
- // When filtering recompute which items are visible for arrow up/down
- // Try and maintain selection
- let visibleItems = [];
- for (const item of items) {
- const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
- if (visible) {
- item.style.display = "block";
- if (item === selectedItem) {
- selectedIndex = visibleItems.length;
- }
- visibleItems.push(item);
- } else {
- item.style.display = "none";
- if (item === selectedItem) {
- selectedIndex = 0;
- }
- }
- }
- items = visibleItems;
+ let selectedIndex = values.findIndex(v => v === clickedComboValue);
+ let selectedItem = displayedItems?.[selectedIndex];
updateSelected();
- // If we have an event then we can try and position the list under the source
- if (options.event) {
- let top = options.event.clientY - 10;
-
- const bodyRect = document.body.getBoundingClientRect();
- const rootRect = this.root.getBoundingClientRect();
- if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
- top = Math.max(0, bodyRect.height - rootRect.height - 10);
- }
-
- this.root.style.top = top + "px";
- positionList();
+ // Apply highlighting to the selected item
+ function updateSelected() {
+ selectedItem?.style.setProperty("background-color", "");
+ selectedItem?.style.setProperty("color", "");
+ selectedItem = displayedItems[selectedIndex];
+ selectedItem?.style.setProperty("background-color", "#ccc", "important");
+ selectedItem?.style.setProperty("color", "#000", "important");
}
- });
- requestAnimationFrame(() => {
- // Focus the filter box when opening
- filter.focus();
+ const positionList = () => {
+ const rect = this.root.getBoundingClientRect();
- positionList();
- });
+ // If the top is off-screen then shift the element with scaling applied
+ if (rect.top < 0) {
+ const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
+ const shift = (this.root.clientHeight * scale) / 2;
+ this.root.style.top = -shift + "px";
+ }
+ }
+
+ // Arrow up/down to select items
+ filter.addEventListener("keydown", (event) => {
+ switch (event.key) {
+ case "ArrowUp":
+ event.preventDefault();
+ if (selectedIndex === 0) {
+ selectedIndex = itemCount - 1;
+ } else {
+ selectedIndex--;
+ }
+ updateSelected();
+ break;
+ case "ArrowRight":
+ event.preventDefault();
+ selectedIndex = itemCount - 1;
+ updateSelected();
+ break;
+ case "ArrowDown":
+ event.preventDefault();
+ if (selectedIndex === itemCount - 1) {
+ selectedIndex = 0;
+ } else {
+ selectedIndex++;
+ }
+ updateSelected();
+ break;
+ case "ArrowLeft":
+ event.preventDefault();
+ selectedIndex = 0;
+ updateSelected();
+ break;
+ case "Enter":
+ selectedItem?.click();
+ break;
+ case "Escape":
+ this.close();
+ break;
+ }
+ });
+
+ filter.addEventListener("input", () => {
+ // Hide all items that don't match our filter
+ const term = filter.value.toLocaleLowerCase();
+ // When filtering, recompute which items are visible for arrow up/down and maintain selection.
+ displayedItems = items.filter(item => {
+ const isVisible = !term || item.textContent.toLocaleLowerCase().includes(term);
+ item.style.display = isVisible ? "block" : "none";
+ return isVisible;
+ });
+
+ selectedIndex = 0;
+ if (displayedItems.includes(selectedItem)) {
+ selectedIndex = displayedItems.findIndex(d => d === selectedItem);
+ }
+ itemCount = displayedItems.length;
+
+ updateSelected();
+
+ // If we have an event then we can try and position the list under the source
+ if (options.event) {
+ let top = options.event.clientY - 10;
+
+ const bodyRect = document.body.getBoundingClientRect();
+ const rootRect = this.root.getBoundingClientRect();
+ if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
+ top = Math.max(0, bodyRect.height - rootRect.height - 10);
+ }
+
+ this.root.style.top = top + "px";
+ positionList();
+ }
+ });
+
+ requestAnimationFrame(() => {
+ // Focus the filter box when opening
+ filter.focus();
+
+ positionList();
+ });
+ })
}
return ctx;
@@ -134,4 +140,6 @@ app.registerExtension({
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
},
-});
+}
+
+app.registerExtension(ext);
diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js
index 9401678b0..5b8304711 100644
--- a/web/extensions/core/slotDefaults.js
+++ b/web/extensions/core/slotDefaults.js
@@ -10,7 +10,7 @@ app.registerExtension({
LiteGraph.middle_click_slot_add_default_node = true;
this.suggestionsNumber = app.ui.settings.addSetting({
id: "Comfy.NodeSuggestions.number",
- name: "number of nodes suggestions",
+ name: "Number of nodes suggestions",
type: "slider",
attrs: {
min: 1,
diff --git a/web/index.html b/web/index.html
index da0adb6c2..c48d716e1 100644
--- a/web/index.html
+++ b/web/index.html
@@ -7,6 +7,7 @@
+